Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 7efdbee

Browse files
Lukasz Kaiser and Ryan Sepassi
authored and committed
Small transformer models (reasonable translations in 1h on 1080).
PiperOrigin-RevId: 164207044
1 parent 554973f commit 7efdbee

File tree

4 files changed

+272
-2
lines changed

4 files changed

+272
-2
lines changed

tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensor2tensor.data_generators import algorithmic
2323
from tensor2tensor.data_generators import algorithmic_math
2424
from tensor2tensor.data_generators import audio
25+
from tensor2tensor.data_generators import cipher
2526
from tensor2tensor.data_generators import desc2code
2627
from tensor2tensor.data_generators import image
2728
from tensor2tensor.data_generators import lm1b
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# coding=utf-8
2+
# Copyright 2017 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Cipher data generators."""
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
from collections import deque
22+
23+
# Dependency imports
24+
25+
import numpy as np
26+
27+
from tensor2tensor.data_generators import algorithmic
28+
from tensor2tensor.utils import registry
29+
30+
31+
@registry.register_problem
class CipherShift5(algorithmic.AlgorithmicProblem):
  """Shift cipher over a 5-symbol vocabulary."""

  @property
  def num_symbols(self):
    # Size of the plaintext alphabet.
    return 5

  @property
  def distribution(self):
    # Sampling probability of each of the 5 symbols; sums to 1.
    return [0.4, 0.3, 0.2, 0.08, 0.02]

  @property
  def shift(self):
    # Amount by which the cipher rotates the alphabet.
    return 1

  @property
  def train_generator(self):
    """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases."""

    def _make_examples(nbr_symbols, max_length, nbr_cases):
      vocab = range(nbr_symbols)
      plain = generate_plaintext_random(vocab, self.distribution, nbr_cases,
                                        max_length)
      enciphered = encipher_shift(plain, vocab, self.shift)
      for plaintext, ciphertext in zip(plain, enciphered):
        yield {"X": plaintext, "Y": ciphertext}

    return _make_examples

  @property
  def train_length(self):
    return 100

  @property
  def dev_length(self):
    return self.train_length
74+
@registry.register_problem
class CipherVigenere5(algorithmic.AlgorithmicProblem):
  """Vigenere cipher over a 5-symbol vocabulary."""

  @property
  def num_symbols(self):
    # Size of the plaintext alphabet.
    return 5

  @property
  def distribution(self):
    # Sampling probability of each of the 5 symbols; sums to 1.
    return [0.4, 0.3, 0.2, 0.08, 0.02]

  @property
  def key(self):
    # Cyclic Vigenere key: successive characters are shifted by 1 and 3.
    return [1, 3]

  @property
  def train_generator(self):
    """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases."""

    def _gen(nbr_symbols, max_length, nbr_cases):
      plain_vocab = range(nbr_symbols)
      indices = generate_plaintext_random(plain_vocab, self.distribution,
                                          nbr_cases, max_length)
      codes = encipher_vigenere(indices, plain_vocab, self.key)

      for plain, code in zip(indices, codes):
        yield {
            "X": plain,
            "Y": code,
        }

    return _gen

  @property
  def train_length(self):
    return 200

  @property
  def dev_length(self):
    return self.train_length
117+
@registry.register_problem
class CipherShift200(CipherShift5):
  """Shift cipher over a 200-symbol vocabulary."""

  @property
  def num_symbols(self):
    return 200

  @property
  def distribution(self):
    # Probability of symbol v is proportional to v, normalized to sum to 1
    # (so symbol 0 is never sampled).
    weights = range(self.num_symbols)
    total = sum(weights)
    return [w / total for w in weights]
132+
@registry.register_problem
class CipherVigenere200(CipherVigenere5):
  """Vigenere cipher over a 200-symbol vocabulary."""

  @property
  def num_symbols(self):
    return 200

  @property
  def distribution(self):
    # Probability of symbol v is proportional to v, normalized to sum to 1
    # (so symbol 0 is never sampled).
    vals = range(self.num_symbols)
    val_sum = sum(vals)
    return [v / val_sum for v in vals]

  @property
  def key(self):
    # Cyclic Vigenere key: successive characters are shifted by 1 and 3.
    return [1, 3]
151+
class Layer(object):
  """One shift (rotation) layer of a cipher: a bijective character map."""

  def __init__(self, vocab, shift):
    """Build the forward and inverse character maps for a single shift.

    Args:
      vocab: (list of String) the vocabulary
      shift: (Integer) the amount of shift apply to the alphabet.
        Positive number implies shift to the right, negative number
        implies shift to the left.
    """
    self.shift = shift
    # Rotate a copy of the alphabet; deque.rotate(k) moves entries k places
    # to the right for positive k.
    rotated = deque(vocab)
    rotated.rotate(shift)
    rotated = list(rotated)
    # Forward map (encrypt) and its inverse (decrypt).
    self.encrypt = dict(zip(vocab, rotated))
    self.decrypt = dict(zip(rotated, vocab))

  def encrypt_character(self, character):
    return self.encrypt[character]

  def decrypt_character(self, character):
    return self.decrypt[character]
177+
def generate_plaintext_random(plain_vocab, distribution, train_samples,
                              length):
  """Generates samples of text from the provided vocabulary.

  Args:
    plain_vocab: vocabulary of symbols to sample from.
    distribution: optional probability of drawing each symbol; must have the
      same length as plain_vocab when given. If None, symbols are drawn
      uniformly.
    train_samples: number of samples (rows) to generate.
    length: length of each sample.

  Returns:
    train_indices (np.array of Integers): random indices into plain_vocab,
      shape = [train_samples, length].
  """
  if distribution is not None:
    assert len(distribution) == len(plain_vocab)

  train_indices = np.random.choice(
      range(len(plain_vocab)), (train_samples, length), p=distribution)

  return train_indices
203+
def encipher_shift(plaintext, plain_vocab, shift):
  """Encrypt plain text with a single shift layer.

  Args:
    plaintext (list of list of Strings): a list of plain text to encrypt.
    plain_vocab (list of Integer): unique vocabularies being used.
    shift (Integer): number of shift, shift to the right if shift is positive.
  Returns:
    ciphertext (list of Strings): encrypted plain text.
  """
  cipher = Layer(plain_vocab, shift)

  # Iterate over values directly: the indices produced by the previous
  # enumerate() calls were never used.
  ciphertext = []
  for sentence in plaintext:
    cipher_sentence = [cipher.encrypt_character(c) for c in sentence]
    ciphertext.append(cipher_sentence)

  return ciphertext
226+
def encipher_vigenere(plaintext, plain_vocab, key):
  """Encrypt plain text with given key.

  Args:
    plaintext (list of list of Strings): a list of plain text to encrypt.
    plain_vocab (list of Integer): unique vocabularies being used.
    key (list of Integer): key to encrypt cipher using Vigenere table.

  Returns:
    ciphertext (list of Strings): encrypted plain text.
  """
  # Generate the Vigenere table: one shift layer per possible shift amount.
  layers = [Layer(plain_vocab, i) for i in range(len(plain_vocab))]

  ciphertext = []
  key_length = len(key)  # hoisted out of the per-character loop
  for sentence in plaintext:
    cipher_sentence = []
    for j, character in enumerate(sentence):
      # The key repeats cyclically along the sentence; each key value selects
      # which shift layer encrypts this character.
      layer = layers[key[j % key_length]]
      cipher_sentence.append(layer.encrypt_character(character))
    ciphertext.append(cipher_sentence)

  return ciphertext

tensor2tensor/layers/common_layers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ def inverse_exp_decay(max_step, min_value=0.01):
5959
return inv_base**tf.maximum(float(max_step) - step, 0.0)
6060

6161

62+
def inverse_lin_decay(max_step, min_value=0.01):
  """Inverse-decay linearly from min_value (default 0.01) to 1.0 at max_step."""
  global_step = tf.to_float(tf.contrib.framework.get_global_step())
  # Fraction of the decay completed, clipped to 1.0 after max_step.
  fraction = tf.minimum(global_step / float(max_step), 1.0)
  return min_value + (1.0 - min_value) * fraction
6269
def shakeshake2_py(x, y, equal=False, individual=False):
6370
"""The shake-shake sum of 2 tensors, python version."""
6471
if equal:

tensor2tensor/models/transformer.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,19 @@ def transformer_parsing_ice():
386386
@registry.register_hparams
def transformer_tiny():
  """Tiny transformer hparams: 2 layers, hidden size 128, 4 heads."""
  hparams = transformer_base()
  hparams.num_heads = 4
  hparams.num_hidden_layers = 2
  hparams.hidden_size = 128
  hparams.filter_size = 512
  return hparams
396+
@registry.register_hparams
def transformer_small():
  """Small transformer hparams: 2 layers, hidden size 256, 4 heads."""
  hparams = transformer_base()
  hparams.num_heads = 4
  hparams.num_hidden_layers = 2
  hparams.hidden_size = 256
  hparams.filter_size = 1024
  return hparams
393404

0 commit comments

Comments
 (0)