|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2017 The Tensor2Tensor Authors. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""VAE Transformer.""" |
| 17 | + |
| 18 | +from __future__ import absolute_import |
| 19 | +from __future__ import division |
| 20 | +from __future__ import print_function |
| 21 | + |
| 22 | +# Dependency imports |
| 23 | + |
| 24 | +from six.moves import xrange # pylint: disable=redefined-builtin |
| 25 | + |
| 26 | +from tensor2tensor.layers import common_layers |
| 27 | +from tensor2tensor.models import transformer |
| 28 | +from tensor2tensor.utils import registry |
| 29 | +from tensor2tensor.utils import t2t_model |
| 30 | + |
| 31 | +import tensorflow as tf |
| 32 | + |
| 33 | + |
| 34 | +def decompress(source, hparams, name): |
| 35 | + """Decompression function.""" |
| 36 | + with tf.variable_scope(name): |
| 37 | + shape = tf.shape(source) |
| 38 | + thicker = common_layers.conv_block( |
| 39 | + source, hparams.hidden_size * 2, [((1, 1), (1, 1))], |
| 40 | + name="decompress_conv") |
| 41 | + return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size]) |
| 42 | + |
| 43 | + |
| 44 | +def vae(x, hparams, name): |
| 45 | + with tf.variable_scope(name): |
| 46 | + mu = tf.layers.dense(x, hparams.z_size, name="mu") |
| 47 | + log_sigma = tf.layers.dense(x, hparams.z_size, name="log_sigma") |
| 48 | + shape = tf.shape(x) |
| 49 | + epsilon = tf.random_normal([shape[0], shape[1], 1, hparams.z_size]) |
| 50 | + z = mu + tf.exp(log_sigma / 2) * epsilon |
| 51 | + dense = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense") |
| 52 | + kl = 0.5 * tf.reduce_mean( |
| 53 | + tf.exp(log_sigma) + tf.square(mu) - 1. - log_sigma, axis=-1) |
| 54 | + return dense, tf.reduce_mean(kl) |
| 55 | + |
| 56 | + |
| 57 | +def compress_vae(inputs, hparams, name): |
| 58 | + """Compress, then VAE.""" |
| 59 | + with tf.variable_scope(name): |
| 60 | + # Run compression by strided convs. |
| 61 | + cur = tf.expand_dims(inputs, axis=2) |
| 62 | + for i in xrange(hparams.num_compress_steps): |
| 63 | + cur = common_layers.conv_block( |
| 64 | + cur, hparams.hidden_size, [((1, 1), (2, 1))], |
| 65 | + strides=(2, 1), name="compress_%d" % i) |
| 66 | + |
| 67 | + # Convolve and ReLu to get state. |
| 68 | + cur = common_layers.conv_block( |
| 69 | + cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv") |
| 70 | + |
| 71 | + cur, kl_loss = vae(cur, hparams, name="vae") |
| 72 | + return cur, kl_loss |
| 73 | + |
| 74 | + |
| 75 | +def vae_transformer_internal(inputs, targets, target_space, hparams): |
| 76 | + """VAE Transformer, main step used for training.""" |
| 77 | + with tf.variable_scope("vae_transformer"): |
| 78 | + is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN |
| 79 | + # Prepare inputs, targets, and k. |
| 80 | + inputs = common_layers.flatten4d3d(inputs) |
| 81 | + targets = common_layers.flatten4d3d(targets) |
| 82 | + k = 2**hparams.num_compress_steps |
| 83 | + _, targets = common_layers.pad_to_same_length( |
| 84 | + inputs, targets, final_length_divisible_by=k) |
| 85 | + |
| 86 | + # Transformer preparations and encoder. |
| 87 | + (encoder_input, encoder_self_attention_bias, |
| 88 | + encoder_decoder_attention_bias) = transformer.transformer_prepare_encoder( |
| 89 | + inputs, target_space, hparams) |
| 90 | + residual_fn = transformer.get_residual_fn(hparams) |
| 91 | + encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout) |
| 92 | + encoder_output = transformer.transformer_encoder( |
| 93 | + encoder_input, residual_fn, encoder_self_attention_bias, hparams) |
| 94 | + |
| 95 | + def get_decoder_autoregressive(): |
| 96 | + """Decoder input for autoregressive computation.""" |
| 97 | + (a, b) = transformer.transformer_prepare_decoder(targets, hparams) |
| 98 | + return (a, b, tf.constant(0.0)) |
| 99 | + |
| 100 | + # 10% of the time we compress all-zeros, as will be at decoding start. |
| 101 | + prob_targets = 0.9 if is_training else 1.0 |
| 102 | + to_compress = tf.cond(tf.less(tf.random_uniform([]), prob_targets), |
| 103 | + lambda: targets, lambda: tf.zeros_like(targets)) |
| 104 | + z, kl_loss = compress_vae(to_compress, hparams, "vae") |
| 105 | + # Decompress. |
| 106 | + for i in xrange(hparams.num_compress_steps): |
| 107 | + j = hparams.num_hidden_layers - i - 1 |
| 108 | + z = decompress(z, hparams, "decompress_%d" % j) |
| 109 | + |
| 110 | + def get_decoder_from_vae(): |
| 111 | + """Decoder input computed by VAE.""" |
| 112 | + # Return decoder stuff. |
| 113 | + (a, b) = transformer.transformer_prepare_decoder( |
| 114 | + tf.squeeze(z, axis=2), hparams) |
| 115 | + return (a, b, kl_loss) |
| 116 | + |
| 117 | + # Randomize decoder inputs.. |
| 118 | + prob_do_vae = common_layers.inverse_exp_decay(40000) * 0.7 |
| 119 | + step = tf.to_float(tf.contrib.framework.get_global_step()) |
| 120 | + if not is_training: |
| 121 | + prob_do_vae = tf.cond(tf.less(step, 40000.0), lambda: tf.constant(0.0), |
| 122 | + lambda: tf.constant(1.0)) |
| 123 | + (decoder_input, decoder_self_attention_bias, kl_loss2) = tf.cond( |
| 124 | + tf.less(tf.random_uniform([]), prob_do_vae), |
| 125 | + get_decoder_from_vae, get_decoder_autoregressive) |
| 126 | + |
| 127 | + # Transformer decoder. |
| 128 | + decoder_output = transformer.transformer_decoder( |
| 129 | + decoder_input, encoder_output, residual_fn, decoder_self_attention_bias, |
| 130 | + encoder_decoder_attention_bias, hparams) |
| 131 | + decoder_output = tf.expand_dims(decoder_output, 2) |
| 132 | + |
| 133 | + cond_self = tf.cond(tf.less(step, 30000.0), lambda: tf.constant(1.0), |
| 134 | + lambda: tf.constant(0.0)) |
| 135 | + prob_self = 0.4 if is_training else cond_self |
| 136 | + (ret, kl_loss) = tf.cond(tf.less(tf.random_uniform([]), prob_self), |
| 137 | + lambda: (z, kl_loss), |
| 138 | + lambda: (decoder_output, kl_loss2)) |
| 139 | + |
| 140 | + kl_loss *= common_layers.inverse_exp_decay(50000) * 2.0 |
| 141 | + return ret, kl_loss |
| 142 | + |
| 143 | + |
| 144 | +@registry.register_model |
| 145 | +class TransformerVAE(t2t_model.T2TModel): |
| 146 | + |
| 147 | + def model_fn_body(self, features): |
| 148 | + return vae_transformer_internal( |
| 149 | + features["inputs"], features["targets"], features["target_space_id"], |
| 150 | + self._hparams) |
| 151 | + |
| 152 | + def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1, |
| 153 | + last_position_only=False, alpha=0.0): |
| 154 | + """A inference method, see T2TModel.""" |
| 155 | + if not features: |
| 156 | + features = {} |
| 157 | + inputs_old = None |
| 158 | + if "inputs" in features and len(features["inputs"].shape) < 4: |
| 159 | + inputs_old = features["inputs"] |
| 160 | + features["inputs"] = tf.expand_dims(features["inputs"], 2) |
| 161 | + |
| 162 | + # Create an initial targets tensor. |
| 163 | + if "partial_targets" in features: |
| 164 | + initial_output = tf.convert_to_tensor(features["partial_targets"]) |
| 165 | + else: |
| 166 | + batch_size = tf.shape(features["inputs"])[0] |
| 167 | + initial_output = tf.zeros((batch_size, 1, 1, 1), dtype=tf.int64) |
| 168 | + |
| 169 | + features["targets"] = initial_output |
| 170 | + sharded_logits, _ = self.model_fn( |
| 171 | + features, False, last_position_only=last_position_only) |
| 172 | + sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4) |
| 173 | + samples = tf.concat(sharded_samples, 0) |
| 174 | + if inputs_old is not None: # Restore to not confuse Estimator. |
| 175 | + features["inputs"] = inputs_old |
| 176 | + return samples |
| 177 | + |
| 178 | + |
| 179 | +@registry.register_hparams |
| 180 | +def transformer_vae_small(): |
| 181 | + """Set of hyperparameters.""" |
| 182 | + hparams = transformer.transformer_small() |
| 183 | + hparams.add_hparam("z_size", 128) |
| 184 | + hparams.add_hparam("num_compress_steps", 4) |
| 185 | + return hparams |
0 commit comments