Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 328911c

Browse files
authored
Merge pull request #206 from aidangomez/master
Algorithmic Shift and Vigenere Cipher Data Generator
2 parents 82cce52 + 90e72af commit 328911c

File tree

2 files changed

+214
-0
lines changed

2 files changed

+214
-0
lines changed

tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensor2tensor.data_generators import algorithmic
2323
from tensor2tensor.data_generators import algorithmic_math
2424
from tensor2tensor.data_generators import audio
25+
from tensor2tensor.data_generators import cipher
2526
from tensor2tensor.data_generators import image
2627
from tensor2tensor.data_generators import lm1b
2728
from tensor2tensor.data_generators import ptb
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
from collections import deque
2+
import numpy as np
3+
4+
from tensor2tensor.data_generators import problem, algorithmic
5+
from tensor2tensor.utils import registry
6+
7+
8+
@registry.register_problem
class CipherShift5(algorithmic.AlgorithmicProblem):
    """Algorithmic problem: shift-cipher encryption over a 5-symbol vocabulary."""

    @property
    def num_symbols(self):
        """Number of symbols in the plaintext vocabulary."""
        return 5

    @property
    def distribution(self):
        """Sampling probability for each vocabulary index."""
        return [0.4, 0.3, 0.2, 0.08, 0.02]

    @property
    def shift(self):
        """Shift amount handed to the shift cipher."""
        return 1

    @property
    def train_generator(self):
        """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases."""

        def _gen(nbr_symbols, max_length, nbr_cases):
            # The vocabulary is simply the integers 0..nbr_symbols-1.
            vocab = range(nbr_symbols)
            plaintexts = generate_plaintext_random(
                vocab, self.distribution, nbr_cases, max_length)
            ciphertexts = encipher_shift(plaintexts, vocab, self.shift)

            # Emit one example per (plaintext, ciphertext) pair.
            for plain_seq, cipher_seq in zip(plaintexts, ciphertexts):
                yield {"X": plain_seq, "Y": cipher_seq}

        return _gen

    @property
    def train_length(self):
        """Sequence length used for training examples."""
        return 100

    @property
    def dev_length(self):
        """Sequence length used for dev examples; same as training."""
        return self.train_length
48+
49+
50+
@registry.register_problem
class CipherVigenere5(algorithmic.AlgorithmicProblem):
    """Algorithmic problem: Vigenere encryption over a 5-symbol vocabulary."""

    @property
    def num_symbols(self):
        """Number of symbols in the plaintext vocabulary."""
        return 5

    @property
    def distribution(self):
        """Sampling probability for each vocabulary index."""
        return [0.4, 0.3, 0.2, 0.08, 0.02]

    @property
    def key(self):
        """Vigenere key: the shift used at each position, repeated cyclically."""
        return [1, 3]

    @property
    def train_generator(self):
        """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases."""

        def _gen(nbr_symbols, max_length, nbr_cases):
            # The vocabulary is simply the integers 0..nbr_symbols-1.
            vocab = range(nbr_symbols)
            plaintexts = generate_plaintext_random(
                vocab, self.distribution, nbr_cases, max_length)
            ciphertexts = encipher_vigenere(plaintexts, vocab, self.key)

            # Emit one example per (plaintext, ciphertext) pair.
            for plain_seq, cipher_seq in zip(plaintexts, ciphertexts):
                yield {"X": plain_seq, "Y": cipher_seq}

        return _gen

    @property
    def train_length(self):
        """Sequence length used for training examples."""
        return 200

    @property
    def dev_length(self):
        """Sequence length used for dev examples; same as training."""
        return self.train_length
90+
91+
92+
@registry.register_problem
class CipherShift200(CipherShift5):
    """Shift-cipher problem over a 200-symbol vocabulary."""

    @property
    def num_symbols(self):
        """Number of symbols in the plaintext vocabulary."""
        return 200

    @property
    def distribution(self):
        """Sampling probabilities proportional to the symbol index.

        Symbol ``i`` gets weight ``i``, so symbol 0 has probability 0 and the
        weights are normalised to sum to 1.
        """
        weights = range(self.num_symbols)
        # Use a float denominator so the division is true division even on
        # Python 2 without `from __future__ import division`; int / int would
        # otherwise floor every probability to 0.
        total = float(sum(weights))
        return [w / total for w in weights]
104+
105+
106+
@registry.register_problem
class CipherVigenere200(CipherVigenere5):
    """Vigenere-cipher problem over a 200-symbol vocabulary."""

    @property
    def num_symbols(self):
        """Number of symbols in the plaintext vocabulary."""
        return 200

    @property
    def distribution(self):
        """Sampling probabilities proportional to the symbol index.

        Symbol ``i`` gets weight ``i``, so symbol 0 has probability 0 and the
        weights are normalised to sum to 1.
        """
        weights = range(self.num_symbols)
        # Use a float denominator so the division is true division even on
        # Python 2 without `from __future__ import division`; int / int would
        # otherwise floor every probability to 0.
        total = float(sum(weights))
        return [w / total for w in weights]

    @property
    def key(self):
        """Vigenere key: the shift used at each position, repeated cyclically."""
        return [1, 3]
122+
123+
124+
class Layer:
    """One substitution table produced by rotating the vocabulary."""

    def __init__(self, vocab, shift):
        """Build the encrypt/decrypt lookup tables for a single shift.

        Args:
          vocab: the vocabulary symbols, in order.
          shift (Integer): how far the alphabet is rotated. Positive values
            rotate to the right, negative values rotate to the left.
        """
        rotated = deque(vocab)
        rotated.rotate(shift)
        # Materialise once; both tables pair the same rotated sequence
        # against the original vocabulary.
        rotated = list(rotated)

        self.shift = shift
        self.encrypt = {plain: coded for plain, coded in zip(vocab, rotated)}
        self.decrypt = {coded: plain for coded, plain in zip(rotated, vocab)}

    def encrypt_character(self, character):
        """Return the ciphertext symbol for a plaintext symbol."""
        return self.encrypt[character]

    def decrypt_character(self, character):
        """Return the plaintext symbol for a ciphertext symbol."""
        return self.decrypt[character]
147+
148+
149+
def generate_plaintext_random(plain_vocab, distribution, train_samples,
                              length):
    """Sample random plaintext sequences as indices into the vocabulary.

    Args:
      plain_vocab: sized collection of vocabulary symbols. Only its length is
        used; samples are indices in ``[0, len(plain_vocab))``.
      distribution: per-index sampling probabilities, one per vocabulary
        entry, or None to let numpy sample uniformly.
      train_samples (Integer): number of sequences to generate.
      length (Integer): number of symbols in each sequence.

    Returns:
      train_indices (np.array of Integers): sampled vocabulary indices.
        shape = [train_samples, length]

    Raises:
      ValueError: if ``distribution`` is given but its length differs from
        the vocabulary size.
    """
    vocab_size = len(plain_vocab)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if distribution is not None and len(distribution) != vocab_size:
        raise ValueError(
            "distribution has %d entries but the vocabulary has %d symbols"
            % (len(distribution), vocab_size))

    # An integer first argument makes numpy sample from arange(vocab_size).
    return np.random.choice(
        vocab_size, size=(train_samples, length), p=distribution)
166+
167+
168+
def encipher_shift(plaintext, plain_vocab, shift):
    """Encrypt every sequence with one shared shift layer.

    Args:
      plaintext: iterable of sequences of vocabulary symbols to encrypt.
      plain_vocab: the vocabulary the shift layer is built from.
      shift (Integer): number of shift, shift to the right if shift is
        positive.

    Returns:
      ciphertext (list of lists): one encrypted sequence per input sequence,
        in the same order.
    """
    layer = Layer(plain_vocab, shift)
    return [
        [layer.encrypt_character(symbol) for symbol in sentence]
        for sentence in plaintext
    ]
188+
189+
190+
def encipher_vigenere(plaintext, plain_vocab, key):
    """Encrypt every sequence with a Vigenere cipher.

    The symbol at position ``j`` of each sequence is encrypted by a shift
    layer whose shift is ``key[j % len(key)]``.

    Args:
      plaintext: iterable of sequences of vocabulary symbols to encrypt.
      plain_vocab: the vocabulary the shift layers are built from.
      key (list of Integer): shift to apply at each position; the key is
        repeated cyclically along every sequence.

    Returns:
      ciphertext (list of lists): one encrypted sequence per input sequence,
        in the same order.
    """
    # Build only the shift layers the key actually uses instead of the full
    # len(plain_vocab)-row Vigenere table. Rotation wraps around the
    # vocabulary, so negative shifts give the same layer the full table
    # would have selected by negative indexing.
    layers = {shift: Layer(plain_vocab, shift) for shift in set(key)}
    key_len = len(key)

    return [
        [layers[key[pos % key_len]].encrypt_character(symbol)
         for pos, symbol in enumerate(sentence)]
        for sentence in plaintext
    ]

0 commit comments

Comments
 (0)