This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit d30ec6b
Author: Ryan Sepassi
Add TransformerEncoder and TransformerDecoder models
PiperOrigin-RevId: 164785525
1 parent b669110 · commit d30ec6b

File tree
10 files changed: +143 −159 lines

tensor2tensor/bin/t2t-datagen

File mode changed: 100755 → 100644
Lines changed: 10 additions & 0 deletions
@@ -82,6 +82,16 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "algorithmic_algebra_inverse": (
         lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
+    "ice_parsing_tokens": (
+        lambda: wmt.tabbed_parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, True, "ice", 2**13, 2**8),
+        lambda: wmt.tabbed_parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, False, "ice", 2**13, 2**8)),
+    "ice_parsing_characters": (
+        lambda: wmt.tabbed_parsing_character_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, True),
+        lambda: wmt.tabbed_parsing_character_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, False)),
     "wmt_parsing_tokens_8k": (
         lambda: wmt.parsing_token_generator(
             FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
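For context, each entry in _SUPPORTED_PROBLEM_GENERATORS pairs a training generator with a dev generator, both wrapped in zero-argument lambdas so nothing is built until the problem is actually selected. The snippet below is only an illustrative sketch (the peek_problem helper and the example call are not part of this commit); it shows how such a pair could be looked up and sampled.

import itertools

def peek_problem(problem_name, supported_generators, n=3):
  # Look up the (train, dev) pair of zero-argument callables registered for
  # this problem, invoke them lazily, and pull a few examples from each.
  training_gen, dev_gen = supported_generators[problem_name]
  train_examples = list(itertools.islice(training_gen(), n))
  dev_examples = list(itertools.islice(dev_gen(), n))
  return train_examples, dev_examples

# Hypothetical usage:
#   peek_problem("ice_parsing_tokens", _SUPPORTED_PROBLEM_GENERATORS)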

tensor2tensor/bin/t2t-trainer

File mode changed: 100755 → 100644

tensor2tensor/data_generators/all_problems.py

File mode changed: 100755 → 100644
Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@
 from tensor2tensor.data_generators import wiki
 from tensor2tensor.data_generators import wmt
 from tensor2tensor.data_generators import wsj_parsing
-from tensor2tensor.data_generators import ice_parsing


 # Problem modules that require optional dependencies

tensor2tensor/data_generators/generator_utils.py

File mode changed: 100755 → 100644

tensor2tensor/data_generators/ice_parsing.py

Lines changed: 0 additions & 117 deletions
This file was deleted.

tensor2tensor/data_generators/problem_hparams.py

File mode changed: 100755 → 100644
Lines changed: 37 additions & 0 deletions
@@ -462,6 +462,39 @@ def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size,
   return p


+def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
+  """Icelandic to parse tree translation benchmark.
+
+  Args:
+    model_hparams: a tf.contrib.training.HParams
+    wrong_source_vocab_size: a number used in the filename indicating the
+      approximate vocabulary size. This is not to be confused with the actual
+      vocabulary size.
+
+  Returns:
+    A tf.contrib.training.HParams object.
+  """
+  p = default_problem_hparams()
+  # This vocab file must be present within the data directory.
+  source_vocab_filename = os.path.join(
+      model_hparams.data_dir, "ice_source.vocab.%d" % wrong_source_vocab_size)
+  target_vocab_filename = os.path.join(model_hparams.data_dir,
+                                       "ice_target.vocab.256")
+  source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
+  target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
+  p.input_modality = {
+      "inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size)
+  }
+  p.target_modality = (registry.Modalities.SYMBOL, 256)
+  p.vocabulary = {
+      "inputs": source_subtokenizer,
+      "targets": target_subtokenizer,
+  }
+  p.input_space_id = 18  # Icelandic tokens
+  p.target_space_id = 19  # Icelandic parse tokens
+  return p
+
+
 def img2img_imagenet(unused_model_hparams):
   """Image 2 Image for imagenet dataset."""
   p = default_problem_hparams()
@@ -511,6 +544,10 @@ def image_celeba(unused_model_hparams):
         lm1b_32k,
     "wiki_32k":
         wiki_32k,
+    "ice_parsing_characters":
+        wmt_parsing_characters,
+    "ice_parsing_tokens":
+        lambda p: ice_parsing_tokens(p, 2**13),
     "wmt_parsing_tokens_8k":
         lambda p: wmt_parsing_tokens(p, 2**13),
     "wsj_parsing_tokens_16k":

tensor2tensor/data_generators/wmt.py

File mode changed: 100755 → 100644
Lines changed: 22 additions & 0 deletions
@@ -648,6 +648,28 @@ def target_space_id(self):
     return problem.SpaceID.CS_CHR


+def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix,
+                                   source_vocab_size, target_vocab_size):
+  """Generate source and target data from a single file."""
+  source_vocab = generator_utils.get_or_generate_tabbed_vocab(
+      data_dir, tmp_dir, "parsing_train.pairs", 0,
+      prefix + "_source.vocab.%d" % source_vocab_size, source_vocab_size)
+  target_vocab = generator_utils.get_or_generate_tabbed_vocab(
+      data_dir, tmp_dir, "parsing_train.pairs", 1,
+      prefix + "_target.vocab.%d" % target_vocab_size, target_vocab_size)
+  filename = "parsing_%s" % ("train" if train else "dev")
+  pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
+  return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS)
+
+
+def tabbed_parsing_character_generator(tmp_dir, train):
+  """Generate source and target data from a single file."""
+  character_vocab = text_encoder.ByteTextEncoder()
+  filename = "parsing_%s" % ("train" if train else "dev")
+  pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
+  return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS)
+
+
 def parsing_token_generator(data_dir, tmp_dir, train, vocab_size):
   symbolizer_vocab = generator_utils.get_or_generate_vocab(
       data_dir, tmp_dir, "vocab.endefr.%d" % vocab_size, vocab_size)
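Both new generators delegate to tabbed_generator, which is not shown in this diff. The following is only a sketch of what such a generator presumably does, assuming one tab-separated source/target pair per line and text_encoder-style encode() vocabularies:

def tabbed_generator_sketch(pair_filepath, source_vocab, target_vocab, eos_id):
  # Assumed behavior: each line is "source<TAB>target"; yield dicts of
  # encoded ids with an EOS id appended, as T2T example generators do.
  with open(pair_filepath) as pair_file:
    for line in pair_file:
      if "\t" not in line:
        continue
      source, target = line.rstrip("\n").split("\t", 1)
      yield {
          "inputs": source_vocab.encode(source) + [eos_id],
          "targets": target_vocab.encode(target) + [eos_id],
      }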

tensor2tensor/models/transformer.py

File mode changed: 100755 → 100644
Lines changed: 70 additions & 32 deletions
@@ -55,22 +55,66 @@ def model_fn_body(self, features):
     (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder(
         targets, hparams)

-    encoder_input = tf.nn.dropout(
-        encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
-    decoder_input = tf.nn.dropout(
-        decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
-    encoder_output = transformer_encoder(
-        encoder_input, encoder_self_attention_bias, hparams)
+    encoder_input = tf.nn.dropout(encoder_input,
+                                  1.0 - hparams.layer_prepostprocess_dropout)
+    decoder_input = tf.nn.dropout(decoder_input,
+                                  1.0 - hparams.layer_prepostprocess_dropout)
+    encoder_output = transformer_encoder(encoder_input,
+                                         encoder_self_attention_bias, hparams)

     decoder_output = transformer_decoder(
-        decoder_input, encoder_output,
-        decoder_self_attention_bias,
+        decoder_input, encoder_output, decoder_self_attention_bias,
         encoder_decoder_attention_bias, hparams)
     decoder_output = tf.expand_dims(decoder_output, 2)

     return decoder_output


+@registry.register_model
+class TransformerEncoder(t2t_model.T2TModel):
+  """Transformer, encoder only."""
+
+  def model_fn_body(self, features):
+    hparams = self._hparams
+    inputs = features["inputs"]
+    target_space = features["target_space_id"]
+
+    inputs = common_layers.flatten4d3d(inputs)
+
+    (encoder_input, encoder_self_attention_bias,
+     _) = (transformer_prepare_encoder(inputs, target_space, hparams))
+
+    encoder_input = tf.nn.dropout(encoder_input,
+                                  1.0 - hparams.layer_prepostprocess_dropout)
+    encoder_output = transformer_encoder(encoder_input,
+                                         encoder_self_attention_bias, hparams)
+
+    return encoder_output
+
+
+@registry.register_model
+class TransformerDecoder(t2t_model.T2TModel):
+  """Transformer, decoder only."""
+
+  def model_fn_body(self, features):
+    hparams = self._hparams
+    targets = features["targets"]
+
+    targets = common_layers.flatten4d3d(targets)
+
+    (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder(
+        targets, hparams)
+
+    decoder_input = tf.nn.dropout(decoder_input,
+                                  1.0 - hparams.layer_prepostprocess_dropout)
+
+    decoder_output = transformer_decoder(
+        decoder_input, None, decoder_self_attention_bias, None, hparams)
+    decoder_output = tf.expand_dims(decoder_output, 2)
+
+    return decoder_output
+
+
 def transformer_prepare_encoder(inputs, target_space, hparams):
   """Prepare one shard of the model for the encoder.

@@ -150,14 +194,11 @@ def transformer_encoder(encoder_input,
       with tf.variable_scope("layer_%d" % layer):
         with tf.variable_scope("self_attention"):
           y = common_attention.multihead_attention(
-              common_layers.layer_preprocess(x, hparams),
-              None,
-              encoder_self_attention_bias,
+              common_layers.layer_preprocess(
+                  x, hparams), None, encoder_self_attention_bias,
               hparams.attention_key_channels or hparams.hidden_size,
               hparams.attention_value_channels or hparams.hidden_size,
-              hparams.hidden_size,
-              hparams.num_heads,
-              hparams.attention_dropout)
+              hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
           x = common_layers.layer_postprocess(x, y, hparams)
         with tf.variable_scope("ffn"):
           y = transformer_ffn_layer(
@@ -196,26 +237,23 @@ def transformer_decoder(decoder_input,
       with tf.variable_scope("layer_%d" % layer):
         with tf.variable_scope("self_attention"):
           y = common_attention.multihead_attention(
-              common_layers.layer_preprocess(x, hparams),
-              None,
-              decoder_self_attention_bias,
-              hparams.attention_key_channels or hparams.hidden_size,
-              hparams.attention_value_channels or hparams.hidden_size,
-              hparams.hidden_size,
-              hparams.num_heads,
-              hparams.attention_dropout)
-          x = common_layers.layer_postprocess(x, y, hparams)
-        with tf.variable_scope("encdec_attention"):
-          y = common_attention.multihead_attention(
-              common_layers.layer_preprocess(x, hparams),
-              encoder_output,
-              encoder_decoder_attention_bias,
+              common_layers.layer_preprocess(
+                  x, hparams), None, decoder_self_attention_bias,
               hparams.attention_key_channels or hparams.hidden_size,
               hparams.attention_value_channels or hparams.hidden_size,
-              hparams.hidden_size,
-              hparams.num_heads,
-              hparams.attention_dropout)
+              hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
           x = common_layers.layer_postprocess(x, y, hparams)
+        if encoder_output is not None:
+          assert encoder_decoder_attention_bias is not None
+          with tf.variable_scope("encdec_attention"):
+            y = common_attention.multihead_attention(
+                common_layers.layer_preprocess(
+                    x, hparams), encoder_output, encoder_decoder_attention_bias,
+                hparams.attention_key_channels or hparams.hidden_size,
+                hparams.attention_value_channels or hparams.hidden_size,
+                hparams.hidden_size, hparams.num_heads,
+                hparams.attention_dropout)
+            x = common_layers.layer_postprocess(x, y, hparams)
         with tf.variable_scope("ffn"):
           y = transformer_ffn_layer(
               common_layers.layer_preprocess(x, hparams), hparams)
@@ -393,7 +431,7 @@ def transformer_parsing_big():

 @registry.register_hparams
 def transformer_parsing_ice():
-  """Hparams for parsing and tagging Icelandic text."""
+  """Hparams for parsing Icelandic text."""
   hparams = transformer_base_single_gpu()
   hparams.batch_size = 4096
   hparams.shared_embedding_and_softmax_weights = int(False)
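The encoder-only and decoder-only variants reuse the same building blocks: when TransformerDecoder passes None for both encoder_output and encoder_decoder_attention_bias, the new `if encoder_output is not None:` guard in transformer_decoder simply skips the per-layer encdec_attention sub-layer. A minimal usage sketch follows, assuming the usual t2t registry behavior that @registry.register_model registers a class under its snake_cased name; nothing here beyond the class names comes from this diff.

from tensor2tensor.models import transformer  # noqa: F401 - triggers model registration
from tensor2tensor.utils import registry

# Look up the newly registered models by name (assumed registry API).
encoder_only_cls = registry.model("transformer_encoder")
decoder_only_cls = registry.model("transformer_decoder")

# Each class is a T2TModel whose model_fn_body runs only one half of the
# encoder-decoder stack defined in this file.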

tensor2tensor/utils/decoding.py

File mode changed: 100755 → 100644
Lines changed: 1 addition & 6 deletions
@@ -259,19 +259,14 @@ def _interactive_input_fn(hparams):
   vocabulary = p_hparams.vocabulary["inputs" if has_input else "targets"]
   # This should be longer than the longest input.
   const_array_size = 10000
-  # Import readline if available for command line editing and recall
-  try:
-    import readline
-  except ImportError:
-    pass
   while True:
     prompt = ("INTERACTIVE MODE num_samples=%d decode_length=%d \n"
               " it=<input_type> ('text' or 'image' or 'label')\n"
               " pr=<problem_num> (set the problem number)\n"
               " in=<input_problem> (set the input problem number)\n"
               " ou=<output_problem> (set the output problem number)\n"
               " ns=<num_samples> (changes number of samples)\n"
-              " dl=<decode_length> (changes decode length)\n"
+              " dl=<decode_length> (changes decode legnth)\n"
               " <%s> (decode)\n"
               " q (quit)\n"
               ">" % (num_samples, decode_length, "source_string"
