
Commit a3be70a

Lukasz Kaiser authored and Ryan Sepassi committed
Correct cyclic lr scheme, docs, play with AE.
PiperOrigin-RevId: 166920305
1 parent b54b711 commit a3be70a

4 files changed: 100 additions and 76 deletions.

docs/example_life.md (18 additions & 1 deletion)

@@ -14,4 +14,21 @@ and how all its parts are connected to work together.
 
 ## The Life of an Example
 
-TODO: complete.
+A training example passes the following stages in T2T:
+* raw input (text from command line or file)
+* encoded input after [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `encode` is usually a sparse tensor, e.g., a vector of `tf.int32`s
+* batched input after [data input pipeline](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/data_reader.py#L242) where the inputs, after [Problem.preprocess_examples](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L188) are grouped by their length and made into batches.
+* dense input after being processed by a [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `bottom`.
+* dense output after [T2T.model_fn_body](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L542)
+* back to sparse output through [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `top`.
+* if decoding, back through [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `decode` to display on the screen.
+
+We go into these phases step by step below.
+
+## Feature Encoders
+
+TODO: describe [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) which is a dict of encoders that have `encode` and `decode` functions.
+
+## Modalities
+
+TODO: describe [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) which has `bottom` and `top` but also sharded versions and one for targets.
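The stages listed in the new doc map onto a very small end-to-end sketch. The Python below is illustrative only: `ToyTextEncoder`, `ToyModality`, and `toy_model_fn_body` are made-up stand-ins for `Problem.feature_encoder`, `Modality.bottom`/`top`, and `T2TModel.model_fn_body`, not the actual T2T classes.

```python
import numpy as np

# Toy stand-ins for the stages described in the doc above; these are
# illustrative only, not the real T2T Problem.feature_encoder or Modality.

class ToyTextEncoder(object):
  """Character-level encoder with `encode`/`decode`, like a feature_encoder."""

  def encode(self, s):
    return [ord(c) for c in s]           # raw text -> sparse integer ids

  def decode(self, ids):
    return "".join(chr(i) for i in ids)  # sparse ids -> text for display


class ToyModality(object):
  """Maps sparse ids to dense vectors (`bottom`) and back to logits (`top`)."""

  def __init__(self, vocab_size=256, hidden_size=8):
    emb = np.random.randn(vocab_size, hidden_size)
    # Unit-norm rows so that top(bottom(ids)).argmax recovers the ids.
    self.emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)

  def bottom(self, ids):
    return self.emb[ids]                 # [batch, length, hidden]

  def top(self, body_output):
    return body_output.dot(self.emb.T)   # [batch, length, vocab]


def toy_model_fn_body(x):
  return x                               # identity "model body" for illustration


encoder, modality = ToyTextEncoder(), ToyModality()
ids = encoder.encode("hello")                         # encoded input
batch = np.array([ids])                               # batched input
dense = modality.bottom(batch)                        # dense input after `bottom`
body_out = toy_model_fn_body(dense)                   # dense output of the body
logits = modality.top(body_out)                       # back toward the vocabulary
print(encoder.decode(np.argmax(logits, axis=-1)[0]))  # prints "hello"
```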

docs/index.md (1 addition & 0 deletions)

@@ -25,6 +25,7 @@ documentation, from basic tutorials to full code documentation.
 ## Deep Dive
 
 * [Life of an Example](example_life.md): how all parts of T2T are connected and work together
+* [Distributed Training](distributed_training.md)
 
 ## Code documentation
 
tensor2tensor/models/transformer_vae.py (80 additions & 74 deletions)

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""VAE Transformer."""
+"""AE Transformer."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -32,10 +32,9 @@
 import tensorflow as tf
 
 
-def residual_conv(x, repeat, hparams, name, reuse=None):
+def residual_conv(x, repeat, k, hparams, name, reuse=None):
   """A stack of convolution blocks with residual connections."""
   with tf.variable_scope(name, reuse=reuse):
-    k = (3, 1)
     dilations_and_kernels = [((1, 1), k) for _ in xrange(3)]
     for i in xrange(repeat):
       with tf.variable_scope("repeat_%d" % i):
@@ -72,15 +71,19 @@ def interleave(x, y, axis=1):
   return tf.concat([x, y], axis=axis+1)
 
 
-def decompress_step(source, c, hparams, first_relu, name):
+def decompress_step(source, c, hparams, first_relu, is_2d, name):
   """Decompression function."""
   with tf.variable_scope(name):
     shape = tf.shape(source)
     if c is not None:
       source = attend(source, c, hparams, "decompress_attend")
+    multiplier = 4 if is_2d else 2
+    kernel = (1, 1) if is_2d else (1, 1)
     thicker = common_layers.conv_block(
-        source, hparams.hidden_size * 2, [((1, 1), (1, 1))],
+        source, hparams.hidden_size * multiplier, [((1, 1), kernel)],
         first_relu=first_relu, name="decompress_conv")
+    if is_2d:
+      return tf.depth_to_space(thicker, 2)
     return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size])
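In the 2-D branch added above, `decompress_step` makes the tensor four times thicker in channels and then calls `tf.depth_to_space(thicker, 2)`, trading those channels for a 2x2 spatial upsampling (the 1-D branch keeps the old reshape that only doubles the length). Below is a NumPy sketch of the shape arithmetic, with a hand-rolled analogue of `depth_to_space` under the NHWC layout assumption:

```python
import numpy as np

def depth_to_space(x, block_size):
  # NHWC analogue of tf.depth_to_space: move channel blocks into space.
  b, h, w, c = x.shape
  new_c = c // (block_size ** 2)
  x = x.reshape(b, h, w, block_size, block_size, new_c)
  x = x.transpose(0, 1, 3, 2, 4, 5)   # interleave the blocks spatially
  return x.reshape(b, h * block_size, w * block_size, new_c)

hidden_size = 64
thicker = np.random.randn(1, 8, 8, 4 * hidden_size)  # conv made it 4x thicker
out = depth_to_space(thicker, 2)
print(out.shape)  # (1, 16, 16, 64): both spatial dims doubled, channels restored
```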

@@ -90,7 +93,7 @@ def gumbel_sample(shape):
   return -tf.log(-tf.log(uniform_samples))
 
 
-def dvae(x, hparams, name):
+def dae(x, hparams, name):
   with tf.variable_scope(name):
     m = tf.layers.dense(x, hparams.v_size, name="mask")
     logsm = tf.nn.log_softmax(m)
@@ -128,7 +131,7 @@ def nearest(x, means, hparams):
   _, nearest_idx = tf.nn.top_k(- dist, k=1)
   nearest_hot = tf.one_hot(tf.squeeze(nearest_idx, axis=1), hparams.v_size)
   nearest_hot = tf.reshape(nearest_hot, [tf.shape(x)[0], tf.shape(x)[1],
-                                         1, hparams.v_size])
+                                         tf.shape(x)[2], hparams.v_size])
   return tf.stop_gradient(nearest_hot)
@@ -137,21 +140,23 @@ def kmeans(x, means, hparams, name):
   x_means_hot = nearest(x, means, hparams)
   x_means = tf.gather(means, tf.argmax(x_means_hot, axis=-1))
   kl = tf.reduce_sum(tf.square(x - x_means), axis=-1)
-  return x_means_hot, x_means_hot, tf.reduce_mean(kl) * 10.0
+  return x_means_hot, tf.reduce_mean(kl) * 10.0
 
 
-def compress(x, c, hparams, name):
+def compress(x, c, is_2d, hparams, name):
   """Compress."""
   with tf.variable_scope(name):
     # Run compression by strided convs.
     cur = x
+    k1 = (3, 3) if is_2d else (3, 1)
+    k2 = (2, 2) if is_2d else (2, 1)
     for i in xrange(hparams.num_compress_steps):
       if c is not None:
         cur = attend(cur, c, hparams, "compress_attend_%d" % i)
-      cur = residual_conv(cur, 1, hparams, "compress_rc_%d" % i)
+      cur = residual_conv(cur, 1, k1, hparams, "compress_rc_%d" % i)
       cur = common_layers.conv_block(
-          cur, hparams.hidden_size, [((1, 1), (2, 1))],
-          strides=(2, 1), name="compress_%d" % i)
+          cur, hparams.hidden_size, [((1, 1), k2)],
+          strides=k2, name="compress_%d" % i)
     return cur
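The `nearest`/`kmeans` pair above discretizes each hidden vector by snapping it to the closest of `hparams.v_size` codebook rows ("means") and now returns just the one-hot code plus a scaled squared-distance loss. A minimal NumPy sketch of that idea with toy sizes (not the TF implementation):

```python
import numpy as np

def nearest_one_hot(x, means):
  # x: [n, hidden], means: [v_size, hidden] (the "z_to_dense" codebook).
  dist = ((x[:, None, :] - means[None, :, :]) ** 2).sum(axis=-1)   # [n, v_size]
  idx = dist.argmin(axis=1)
  hot = np.eye(means.shape[0])[idx]                 # one-hot codes, like nearest()
  x_means = means[idx]                              # quantized vectors
  loss = ((x - x_means) ** 2).sum(axis=-1).mean() * 10.0  # as in kmeans()
  return hot, loss

x = np.random.randn(5, 4)
means = np.random.randn(16, 4)
hot, loss = nearest_one_hot(x, means)
print(hot.shape, loss)  # (5, 16) and a scalar loss
```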

@@ -188,7 +193,7 @@ def decode(cond_vec, cond_add, gold, c, ed, hparams):
   decoder_input = tf.squeeze(decoder_input, axis=2)
   decoder_input = common_attention.add_timing_signal_1d(decoder_input)
   bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1])
-  if c is not None:
+  if c is not None and len(c.get_shape()) > 3:
     c = tf.squeeze(c, axis=2)
   return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams)
@@ -205,69 +210,62 @@ def expand_batch(x, mul):
   return tf.reshape(cx, res_shape)
 
 
-def vae_compress(x, c, ed, hparams, compress_name, decompress_name, reuse=None):
-  """Compress, then VAE."""
-  with tf.variable_scope(compress_name, reuse=reuse):
-    cur = compress(x, None, hparams, "compress")
+def ae_compress(x, is_2d, hparams, name, reuse=None):
+  """Compress, then AE."""
+  with tf.variable_scope(name, reuse=reuse):
+    cur = compress(x, None, is_2d, hparams, "compress")
     # Convolve and ReLu to get state.
     cur = common_layers.conv_block(
        cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv")
     cur = tf.nn.l2_normalize(cur, dim=3)
     cur_n = hparams.kmeans_lr_factor * cur
     cur_n += (1.0 - hparams.kmeans_lr_factor) * tf.stop_gradient(cur)
     means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size])
-    # z, kl_loss, mu, log_sigma = vae(cur, hparams, name="vae")
-    # z_true, z_sample, kl_loss = dvae(cur, hparams, name="dvae")
-    z_true, z_sample, kl_loss = kmeans(cur_n, means, hparams, name="kmeans")
-
-  # Compress context.
-  with tf.variable_scope(compress_name, reuse=reuse):
-    compress_c = compress(c, None, hparams, "compress_context")
-    dec_c = decode(None, compress_c, cur, None, None, hparams)
-    c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
-    reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
-        labels=z_true, logits=c_z)
+    hot, loss = kmeans(cur_n, means, hparams, name="kmeans")
+    # We need a linear layer to undo the l2-normalization.
+    cur = tf.layers.dense(cur, hparams.hidden_size, name="unnormalize")
+    return cur, hot, loss
 
-  # If not training, use the predicted z instead of the autoregressive one.
-  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
-    z = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)
 
-  with tf.variable_scope(decompress_name, reuse=reuse):
-    # Decompress.
-    z_sample_flat = tf.reshape(z_sample, [-1, hparams.v_size])
-    z = tf.matmul(z_sample_flat, means)
-    z = tf.reshape(z, [tf.shape(z_sample)[0], tf.shape(z_sample)[1],
-                       1, hparams.hidden_size])
+def ae_embed(hot, hparams, name, reuse=None):
+  with tf.variable_scope(name, reuse=reuse):
+    means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size])
+    hot_flat = tf.reshape(hot, [-1, hparams.v_size])
+    emb = tf.matmul(hot_flat, means)
+    emb = tf.reshape(emb, [tf.shape(hot)[0], tf.shape(hot)[1],
+                           tf.shape(hot)[2], hparams.hidden_size])
+    return tf.layers.dense(emb, hparams.hidden_size,
+                           name="unnormalize", reuse=reuse)
+
 
+def ae_decompress(z, ae, x, is_2d, hparams, name, reuse=None):
+  """Decompress from z, leaking from ae."""
+  with tf.variable_scope(name + "_decompress", reuse=reuse):
     # Leak at the beginning to help train.
-    z = mix(z, cur, hparams.startup_steps)
+    z = mix(z, ae, hparams.startup_steps)
     prob_z = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.8
-    prob_z = prob_z if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0
+    prob_z = prob_z if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN else 1.0
     z = tf.cond(tf.less(tf.random_uniform([]), prob_z),
-                lambda: z, lambda: cur)
-    z = tf.layers.dense(z, hparams.hidden_size, name="unnormalize")
+                lambda: z, lambda: ae)
 
     # Dropout for better autoencoding.
-    z = tf.nn.dropout(z, keep_prob=0.9)
+    z = tf.nn.dropout(z, keep_prob=1.0 - hparams.z_dropout)
 
     # Decompress.
     d = z
     for i in xrange(hparams.num_compress_steps):
       j = hparams.num_compress_steps - i - 1
-      d = residual_conv(d, 1, hparams, "decompress_rc_%d" % j)
-      d = decompress_step(d, c, hparams, i > 0, "decompress_step_%d" % j)
+      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
+      d = decompress_step(d, None, hparams, i > 0, is_2d, "decompress_%d" % j)
 
     k = 2**hparams.num_compress_steps
     z_batch = tf.reshape(z, [-1, 1, 1, hparams.hidden_size])
     x_batch = tf.reshape(x, [-1, k, 1, hparams.hidden_size])
     d_batch = tf.reshape(d, [-1, k, 1, hparams.hidden_size])
-    # dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams)
-    c = expand_batch(c, tf.shape(x_batch)[0] / tf.shape(x)[0])
-    ed = expand_batch(ed, tf.shape(x_batch)[0] / tf.shape(x)[0])
-    dec_batch = decode(z_batch, d_batch, x_batch, c, ed, hparams)
+    dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams)
    z = tf.reshape(dec_batch, [-1, tf.shape(x)[1], 1, hparams.hidden_size])
 
-  return z, kl_loss, reconstruct_loss
+  return z
 
 
 def ffn(x, hparams, name):
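`ae_decompress` leaks the raw autoencoder state into the decoder early in training: with probability `prob_z` it keeps the quantized `z`, otherwise it uses `ae`, and `prob_z` ramps up via `common_layers.inverse_exp_decay` over `startup_steps` (at inference it is fixed at 1.0). A plain-Python sketch of that schedule; the `exp(log(0.01) / max_step)` base is an assumption, mirroring the warm-up formula visible in model_builder.py further down:

```python
import math
import random

def inverse_exp_decay(max_step, step):
  # Plain-Python analogue of common_layers.inverse_exp_decay (assumed form):
  # rises from 0.01 at step 0 toward 1.0 at max_step.
  inv_base = math.exp(math.log(0.01) / max_step)
  return inv_base ** max(max_step - step, 0)

def pick_decoder_input(z, ae, step, startup_steps=30000, training=True):
  # Keep the quantized z with probability prob_z, else leak the raw ae state.
  prob_z = inverse_exp_decay(startup_steps, step) * 0.8 if training else 1.0
  return z if random.random() < prob_z else ae

for step in (0, 10000, 20000, 30000):
  print(step, round(inverse_exp_decay(30000, step) * 0.8, 3))
# 0 0.008, 10000 0.037, 20000 0.172, 30000 0.8
```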
@@ -277,35 +275,42 @@ def ffn(x, hparams, name):
     return common_layers.layer_postprocess(x, y, hparams)
 
 
-def vae_transformer_internal(inputs, targets, target_space, hparams):
-  """VAE Transformer, main step used for training."""
-  with tf.variable_scope("vae_transformer"):
-    # Prepare inputs, targets, and k.
-    inputs = common_layers.flatten4d3d(inputs)
-    input_len = tf.shape(inputs)[1]  # Double input size to cover targets.
-    inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]])
-    inputs.set_shape([None, None, hparams.hidden_size])
-    targets = common_layers.flatten4d3d(targets)
+def ae_transformer_internal(inputs, targets, target_space, hparams):
+  """AE Transformer, main step used for training."""
+  with tf.variable_scope("ae_transformer"):
+    # Prepare inputs, targets, k.
     k = 2**hparams.num_compress_steps
-    inputs, targets = common_layers.pad_to_same_length(
-        inputs, targets, final_length_divisible_by=k)
-    inputs, ed_bias = encode(inputs, target_space, hparams, "input_enc")
-
-    # Compress and vae.
-    z, kl, r = vae_compress(tf.expand_dims(targets, axis=2),
-                            tf.expand_dims(inputs, axis=2),
-                            ed_bias, hparams, "vae_compress", "vae_decompress")
+    _, targets = common_layers.pad_to_same_length(
+        targets, targets, final_length_divisible_by=k)
+    inputs = common_layers.flatten4d3d(inputs)
+    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
+
+    # Compress and ae.
+    ae, hot, kl = ae_compress(targets, False, hparams, "ae")
+    emb = ae_embed(hot, hparams, "ae", reuse=True)
+
+    # Compress context and run autoregressive decoder on emb-hot.
+    dec_c = decode(None, None, emb, inputs, ed, hparams)
+    c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
+    reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
+        labels=hot, logits=c_z)
+    # If not training, use the predicted z instead of the autoregressive one.
+    if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
+      hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)
+
+    # Decompress, pass for ae loss.
+    z = ae_decompress(emb, ae, targets, False, hparams, "ae")
     kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.5))
-    r *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.5))
-    losses = {"kl": kl, "reconstruction": r}
+    reconstruct_loss *= common_layers.inverse_exp_decay(hparams.startup_steps)
+    losses = {"kl": kl, "reconstruction": reconstruct_loss}
     return z, losses
 
 
 @registry.register_model
-class TransformerVAE(t2t_model.T2TModel):
+class TransformerAE(t2t_model.T2TModel):
 
   def model_fn_body(self, features):
-    return vae_transformer_internal(
+    return ae_transformer_internal(
         features["inputs"], features["targets"], features["target_space_id"],
         self._hparams)
 
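In the rewritten `ae_transformer_internal`, the decoder run on the embedded codes produces logits `c_z` over the `hparams.v_size` discrete codes; the reconstruction loss is softmax cross-entropy against the one-hot codes `hot`, and at inference time the argmax of `c_z` replaces those codes. A small NumPy sketch of that loss and swap with toy shapes:

```python
import numpy as np

def softmax_xent(labels_one_hot, logits):
  # Per-position analogue of tf.nn.softmax_cross_entropy_with_logits.
  logits = logits - logits.max(axis=-1, keepdims=True)
  log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
  return -(labels_one_hot * log_probs).sum(axis=-1)

v_size, length = 16, 6
hot = np.eye(v_size)[np.random.randint(v_size, size=length)]  # codes from ae_compress
c_z = np.random.randn(length, v_size)                         # predicted logits
reconstruct_loss = softmax_xent(hot, c_z).mean()

# At inference time, the predicted codes replace the autoencoder's codes:
hot_infer = np.eye(v_size)[np.argmax(c_z, axis=-1)]
print(round(float(reconstruct_loss), 3), hot_infer.shape)     # loss and (6, 16)
```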

@@ -348,7 +353,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
 
 
 @registry.register_hparams
-def transformer_vae_small():
+def transformer_ae_small():
   """Set of hyperparameters."""
   hparams = transformer.transformer_small()
   hparams.batch_size = 2048
@@ -358,19 +363,20 @@ def transformer_vae_small():
   hparams.add_hparam("num_compress_steps", 4)
   hparams.add_hparam("kl_warmup_steps", 60000)
   hparams.add_hparam("startup_steps", 30000)
+  hparams.add_hparam("kmeans_lr_factor", 0.002)
+  hparams.add_hparam("z_dropout", 0.1)
   return hparams
 
 
 @registry.register_hparams
-def transformer_vae_base():
+def transformer_ae_base():
   """Set of hyperparameters."""
-  hparams = transformer_vae_small()
+  hparams = transformer_ae_small()
   hparams.hidden_size = 512
   hparams.filter_size = 2048
   hparams.attention_dropout = 0.0
   hparams.relu_dropout = 0.0
   hparams.dropout = 0.0
   hparams.num_hidden_layers = 4
-  hparams.kmeans_lr_factor = 0.002
   hparams.z_size = 256
   return hparams
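The two hyperparameters added here (`kmeans_lr_factor` and `z_dropout`) can be overridden by registering a derived set, following the same pattern `transformer_ae_base` uses above. The set name and override values below are hypothetical, purely for illustration:

```python
# Hypothetical derived hparams set; the name and values are made up.
from tensor2tensor.models.transformer_vae import transformer_ae_small
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_ae_small_more_dropout():
  hparams = transformer_ae_small()
  hparams.z_dropout = 0.3          # hparam added in this commit
  hparams.kmeans_lr_factor = 0.01  # hparam added in this commit
  return hparams
```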

tensor2tensor/utils/model_builder.py (1 addition & 1 deletion)

@@ -111,7 +111,7 @@ def learning_rate_decay():
     cycle_position = tf.to_float(  # Normalize to the interval [-1, 1].
         cycle_position - cycle_steps) / float(cycle_steps)
     cycle_position = 1.0 - tf.abs(cycle_position)  # 0 to 1 and back to 0.
-    return (cycle_position + 0.01) * 10.0  # 10x difference each cycle.
+    return (cycle_position + 0.1) * 3.0  # 10x difference each cycle (0.3-3).
 
   inv_base = tf.exp(tf.log(0.01) / warmup_steps)
   inv_decay = inv_base**(warmup_steps - step)
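The corrected cyclic scheme returns a multiplier that sweeps linearly from about 0.3 up to about 3.3 and back over each cycle of `2 * cycle_steps`, an overall swing of roughly 10x. A plain-Python sketch; the assumption that `cycle_position` starts as the step modulo `2 * cycle_steps` comes from surrounding code not shown in this hunk:

```python
def cyclic_lr_factor(step, cycle_steps):
  # Assumed preamble (not in this hunk): position within a 2*cycle_steps cycle.
  cycle_position = step % (2 * cycle_steps)
  cycle_position = (cycle_position - cycle_steps) / float(cycle_steps)  # [-1, 1]
  cycle_position = 1.0 - abs(cycle_position)  # 0 to 1 and back to 0.
  return (cycle_position + 0.1) * 3.0         # sweeps roughly 0.3 ... 3.3

for s in (0, 5000, 10000, 15000, 20000):
  print(s, cyclic_lr_factor(s, cycle_steps=10000))
# 0.3, 1.8, ~3.3, 1.8, 0.3
```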
