
Commit f25af0f

Lukasz Kaiser authored and Ryan Sepassi committed
Share desc2code source vocab with translation, baseline to play with VAE.
PiperOrigin-RevId: 164315503
1 parent a0bd017 commit f25af0f

3 files changed: 218 additions, 36 deletions


tensor2tensor/data_generators/desc2code.py

Lines changed: 32 additions & 36 deletions
@@ -44,8 +44,8 @@
 _DESC_DIR_NAME = "description"
 _CODE_PY_DIR_NAME = "solutions_python"

-_VOCAB_EN_FILENAME = "vocab_desc2code_tok_en"
-_VOCAB_PY_FILENAME = "vocab_desc2code_tok_py"
+_VOCAB_EN_FILENAME = "vocab.endefr"
+_VOCAB_PY_FILENAME = "vocab.py"

 # Struct containing a coding problem (contains the paths to the descriptions
 # and code files)
@@ -61,21 +61,43 @@ def is_character_level(self):

   @property
   def num_shards(self):
-    return 100
+    return 10

   @property
   def use_subword_tokenizer(self):
     return True

+  @property
+  def input_vocab_size(self):
+    return 2**15  # 32k
+
+  @property
+  def target_vocab_size(self):
+    return 2**12  # 4k
+
+  @property
+  def vocab_input_filename(self):
+    return "{}.{}".format(_VOCAB_EN_FILENAME, self.input_vocab_size)
+
+  @property
+  def vocab_target_filename(self):
+    return "{}.{}".format(_VOCAB_PY_FILENAME, self.target_vocab_size)
+
+  def feature_encoders(self, data_dir):
+    source_vocab_filename = os.path.join(data_dir, self.vocab_input_filename)
+    target_vocab_filename = os.path.join(data_dir, self.vocab_target_filename)
+    source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
+    target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
+    return {
+        "inputs": source_token,
+        "targets": target_token,
+    }
+

 @registry.register_problem("desc2code_py")
 class Desc2CodePyProblem(Desc2CodeProblem):
   """Description2Code for python problem."""

-  @property
-  def targeted_vocab_size(self):
-    return 2**13  # 8192
-
   @property
   def input_space_id(self):
     return problem.SpaceID.EN_TOK
@@ -84,14 +106,6 @@ def input_space_id(self):
   def target_space_id(self):
     return problem.SpaceID.PY_TOK

-  @property
-  def vocab_input_filename(self):
-    return "{}.{}".format(_VOCAB_EN_FILENAME, self.targeted_vocab_size)
-
-  @property
-  def vocab_target_filename(self):
-    return "{}.{}".format(_VOCAB_PY_FILENAME, self.targeted_vocab_size)
-
   def train_generator(self, data_dir, tmp_dir, train):
     # Called twice: for train and test

@@ -135,27 +149,19 @@ def generator_samples_content(get_source, get_target):
         elif sample.code_files:  # Only take the source if a target exists
           yield source, target

-    def generator_source():
-      for source, _ in generator_samples_content(True, False):
-        yield source.strip()
-
     def generator_target():
       for _, target in generator_samples_content(False, True):
         yield target.strip()

     # Generate vocab for both source and target

-    source_vocab = generator_utils.get_or_generate_vocab_inner(
-        data_dir=data_dir,
-        vocab_filename=self.vocab_input_filename,
-        vocab_size=self.targeted_vocab_size,
-        generator_fn=generator_source,
-    )
+    source_vocab = generator_utils.get_or_generate_vocab(
+        data_dir, tmp_dir, self.vocab_input_filename, self.input_vocab_size)

     target_vocab = generator_utils.get_or_generate_vocab_inner(
         data_dir=data_dir,
         vocab_filename=self.vocab_target_filename,
-        vocab_size=self.targeted_vocab_size,
+        vocab_size=self.target_vocab_size,
         generator_fn=generator_target,
     )
@@ -169,16 +175,6 @@ def generator_target():
             "targets": target_ints,
         }

-  def feature_encoders(self, data_dir):
-    source_vocab_filename = os.path.join(data_dir, self.vocab_input_filename)
-    target_vocab_filename = os.path.join(data_dir, self.vocab_target_filename)
-    source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
-    target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
-    return {
-        "inputs": source_token,
-        "targets": target_token,
-    }
-

 # Utils functions
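The vocabulary change above is the core of this file's diff: the English source side now reuses the shared subword vocabulary of the translation problems ("vocab.endefr") instead of a desc2code-specific one, while the Python target side keeps its own smaller vocabulary, and feature_encoders plus the filename properties move up from Desc2CodePyProblem into the Desc2CodeProblem base class so every subclass shares the same encoder setup. A minimal sketch of how the two new properties resolve to on-disk filenames; the printed names follow directly from the format strings in the diff, but the sketch itself is illustrative and not part of the commit:

_VOCAB_EN_FILENAME = "vocab.endefr"  # shared with the translation problems
_VOCAB_PY_FILENAME = "vocab.py"
input_vocab_size = 2**15   # 32768 subwords for English descriptions
target_vocab_size = 2**12  # 4096 subwords for Python code

print("{}.{}".format(_VOCAB_EN_FILENAME, input_vocab_size))   # vocab.endefr.32768
print("{}.{}".format(_VOCAB_PY_FILENAME, target_vocab_size))  # vocab.py.4096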
tensor2tensor/models/models.py

Lines changed: 1 addition & 0 deletions
@@ -38,5 +38,6 @@
 from tensor2tensor.models import transformer
 from tensor2tensor.models import transformer_alternative
 from tensor2tensor.models import transformer_moe
+from tensor2tensor.models import transformer_vae
 from tensor2tensor.models import xception
 # pylint: enable=unused-import
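The new line is imported purely for its side effect: loading transformer_vae runs the @registry.register_model and @registry.register_hparams decorators it contains, which is why it sits inside the pylint unused-import block with the other models. A hedged sketch of the lookup this enables, assuming the usual T2T registry behavior where a CamelCase class name registers under its snake_case form:

from tensor2tensor.models import models  # pylint: disable=unused-import
from tensor2tensor.utils import registry

model_cls = registry.model("transformer_vae")           # resolves to TransformerVAE
hparams_fn = registry.hparams("transformer_vae_small")  # resolves to the hparams function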
tensor2tensor/models/transformer_vae.py (new file)

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""VAE Transformer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensor2tensor.layers import common_layers
+from tensor2tensor.models import transformer
+from tensor2tensor.utils import registry
+from tensor2tensor.utils import t2t_model
+
+import tensorflow as tf
+
+
+def decompress(source, hparams, name):
+  """Decompression function."""
+  with tf.variable_scope(name):
+    shape = tf.shape(source)
+    thicker = common_layers.conv_block(
+        source, hparams.hidden_size * 2, [((1, 1), (1, 1))],
+        name="decompress_conv")
+    return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size])
+
+
+def vae(x, hparams, name):
+  with tf.variable_scope(name):
+    mu = tf.layers.dense(x, hparams.z_size, name="mu")
+    log_sigma = tf.layers.dense(x, hparams.z_size, name="log_sigma")
+    shape = tf.shape(x)
+    epsilon = tf.random_normal([shape[0], shape[1], 1, hparams.z_size])
+    z = mu + tf.exp(log_sigma / 2) * epsilon
+    dense = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense")
+    kl = 0.5 * tf.reduce_mean(
+        tf.exp(log_sigma) + tf.square(mu) - 1. - log_sigma, axis=-1)
+    return dense, tf.reduce_mean(kl)
+
+
+def compress_vae(inputs, hparams, name):
+  """Compress, then VAE."""
+  with tf.variable_scope(name):
+    # Run compression by strided convs.
+    cur = tf.expand_dims(inputs, axis=2)
+    for i in xrange(hparams.num_compress_steps):
+      cur = common_layers.conv_block(
+          cur, hparams.hidden_size, [((1, 1), (2, 1))],
+          strides=(2, 1), name="compress_%d" % i)
+
+    # Convolve and ReLu to get state.
+    cur = common_layers.conv_block(
+        cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv")
+
+    cur, kl_loss = vae(cur, hparams, name="vae")
+    return cur, kl_loss
+
+
+def vae_transformer_internal(inputs, targets, target_space, hparams):
+  """VAE Transformer, main step used for training."""
+  with tf.variable_scope("vae_transformer"):
+    is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+    # Prepare inputs, targets, and k.
+    inputs = common_layers.flatten4d3d(inputs)
+    targets = common_layers.flatten4d3d(targets)
+    k = 2**hparams.num_compress_steps
+    _, targets = common_layers.pad_to_same_length(
+        inputs, targets, final_length_divisible_by=k)
+
+    # Transformer preparations and encoder.
+    (encoder_input, encoder_self_attention_bias,
+     encoder_decoder_attention_bias) = transformer.transformer_prepare_encoder(
+         inputs, target_space, hparams)
+    residual_fn = transformer.get_residual_fn(hparams)
+    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
+    encoder_output = transformer.transformer_encoder(
+        encoder_input, residual_fn, encoder_self_attention_bias, hparams)
+
+    def get_decoder_autoregressive():
+      """Decoder input for autoregressive computation."""
+      (a, b) = transformer.transformer_prepare_decoder(targets, hparams)
+      return (a, b, tf.constant(0.0))
+
+    # 10% of the time we compress all-zeros, as will be at decoding start.
+    prob_targets = 0.9 if is_training else 1.0
+    to_compress = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
+                          lambda: targets, lambda: tf.zeros_like(targets))
+    z, kl_loss = compress_vae(to_compress, hparams, "vae")
+    # Decompress.
+    for i in xrange(hparams.num_compress_steps):
+      j = hparams.num_hidden_layers - i - 1
+      z = decompress(z, hparams, "decompress_%d" % j)
+
+    def get_decoder_from_vae():
+      """Decoder input computed by VAE."""
+      # Return decoder stuff.
+      (a, b) = transformer.transformer_prepare_decoder(
+          tf.squeeze(z, axis=2), hparams)
+      return (a, b, kl_loss)
+
+    # Randomize decoder inputs.
+    prob_do_vae = common_layers.inverse_exp_decay(40000) * 0.7
+    step = tf.to_float(tf.contrib.framework.get_global_step())
+    if not is_training:
+      prob_do_vae = tf.cond(tf.less(step, 40000.0), lambda: tf.constant(0.0),
+                            lambda: tf.constant(1.0))
+    (decoder_input, decoder_self_attention_bias, kl_loss2) = tf.cond(
+        tf.less(tf.random_uniform([]), prob_do_vae),
+        get_decoder_from_vae, get_decoder_autoregressive)
+
+    # Transformer decoder.
+    decoder_output = transformer.transformer_decoder(
+        decoder_input, encoder_output, residual_fn, decoder_self_attention_bias,
+        encoder_decoder_attention_bias, hparams)
+    decoder_output = tf.expand_dims(decoder_output, 2)
+
+    cond_self = tf.cond(tf.less(step, 30000.0), lambda: tf.constant(1.0),
+                        lambda: tf.constant(0.0))
+    prob_self = 0.4 if is_training else cond_self
+    (ret, kl_loss) = tf.cond(tf.less(tf.random_uniform([]), prob_self),
+                             lambda: (z, kl_loss),
+                             lambda: (decoder_output, kl_loss2))
+
+    kl_loss *= common_layers.inverse_exp_decay(50000) * 2.0
+    return ret, kl_loss
+
+
+@registry.register_model
+class TransformerVAE(t2t_model.T2TModel):
+
+  def model_fn_body(self, features):
+    return vae_transformer_internal(
+        features["inputs"], features["targets"], features["target_space_id"],
+        self._hparams)
+
+  def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
+            last_position_only=False, alpha=0.0):
+    """An inference method, see T2TModel."""
+    if not features:
+      features = {}
+    inputs_old = None
+    if "inputs" in features and len(features["inputs"].shape) < 4:
+      inputs_old = features["inputs"]
+      features["inputs"] = tf.expand_dims(features["inputs"], 2)
+
+    # Create an initial targets tensor.
+    if "partial_targets" in features:
+      initial_output = tf.convert_to_tensor(features["partial_targets"])
+    else:
+      batch_size = tf.shape(features["inputs"])[0]
+      initial_output = tf.zeros((batch_size, 1, 1, 1), dtype=tf.int64)
+
+    features["targets"] = initial_output
+    sharded_logits, _ = self.model_fn(
+        features, False, last_position_only=last_position_only)
+    sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
+    samples = tf.concat(sharded_samples, 0)
+    if inputs_old is not None:  # Restore to not confuse Estimator.
+      features["inputs"] = inputs_old
+    return samples
+
+
+@registry.register_hparams
+def transformer_vae_small():
+  """Set of hyperparameters."""
+  hparams = transformer.transformer_small()
+  hparams.add_hparam("z_size", 128)
+  hparams.add_hparam("num_compress_steps", 4)
+  return hparams
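A note on the math in vae() above: despite the name, log_sigma is used as a log-variance, since samples are drawn as z = mu + exp(log_sigma / 2) * epsilon (the reparameterization trick), and the penalty 0.5 * (exp(log_sigma) + mu**2 - 1 - log_sigma) is the closed-form KL divergence KL(N(mu, sigma**2) || N(0, 1)) per latent dimension. A small self-contained numpy check of that identity; illustrative only, not part of the commit:

import numpy as np

rng = np.random.default_rng(0)
mu, log_var = 0.3, -0.5
std = np.exp(log_var / 2)  # same convention as z = mu + exp(log_sigma / 2) * eps

# Closed form used in vae().
closed = 0.5 * (np.exp(log_var) + mu**2 - 1.0 - log_var)

# Monte Carlo estimate of E_q[log q(z) - log p(z)], q = N(mu, std**2), p = N(0, 1).
z = mu + std * rng.standard_normal(1000000)
log_q = -0.5 * ((z - mu) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)
log_p = -0.5 * z ** 2 - 0.5 * np.log(2 * np.pi)
print(closed, np.mean(log_q - log_p))  # the two values agree closely

The kl_loss returned to the training loop is further scaled by inverse_exp_decay(50000) * 2.0; assuming common_layers.inverse_exp_decay ramps from near 0 toward 1 as the global step approaches its argument, this warms the KL penalty up gradually, a common guard against posterior collapse early in training.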
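On the shape bookkeeping: each compress_%d block is a strided convolution (strides=(2, 1)) that halves the time dimension, and each decompress call projects to 2x-wide channels and reshapes them into two positions, doubling the length back. This is why vae_transformer_internal pads targets to a length divisible by k = 2**num_compress_steps. A quick sketch of the arithmetic; the concrete length 37 is made up for illustration:

num_compress_steps = 4       # from transformer_vae_small
k = 2 ** num_compress_steps  # 16; pad_to_same_length enforces divisibility by k

length = 37
padded = -(-length // k) * k            # 48, the next multiple of 16
latent_length = padded // k             # 3 positions after four halvings
restored = latent_length * k            # 48 again after four decompress doublings
print(padded, latent_length, restored)  # 48 3 48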
