
Commit 1fc6766

T2T Team authored and Ryan Sepassi committed
Create a Problem class for the lm1b dataset.
PiperOrigin-RevId: 166906734
1 parent a8ee62a commit 1fc6766

3 files changed: +115 additions, −33 deletions

tensor2tensor/data_generators/lm1b.py

Lines changed: 69 additions & 27 deletions
@@ -29,8 +29,10 @@
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
 from tensor2tensor.data_generators import tokenizer
+from tensor2tensor.utils import registry
 
 import tensorflow as tf

@@ -53,7 +55,7 @@ def _original_vocab(tmp_dir):
   """
   vocab_url = ("http://download.tensorflow.org/models/LM_LSTM_CNN/"
                "vocab-2016-09-10.txt")
-  vocab_filename = os.path.basename(vocab_url)
+  vocab_filename = os.path.basename(vocab_url + ".en")
   vocab_filepath = os.path.join(tmp_dir, vocab_filename)
   if not os.path.exists(vocab_filepath):
     generator_utils.maybe_download(tmp_dir, vocab_filename, vocab_url)
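
Why the rename: the new Problem class below locates its vocabulary via its vocab_name property, so the cached copy must be saved as "vocab-2016-09-10.txt.en" even though the download URL itself is unchanged. A minimal sketch of the resulting filename (standard library only; the print is illustrative, not from the commit):

import os

vocab_url = ("http://download.tensorflow.org/models/LM_LSTM_CNN/"
             "vocab-2016-09-10.txt")

# Appending ".en" before taking the basename yields the filename that the
# new Problem's vocab_name property expects.
print(os.path.basename(vocab_url + ".en"))  # -> vocab-2016-09-10.txt.en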
@@ -140,29 +142,69 @@ def _get_or_build_subword_text_encoder(tmp_dir):
   return ret
 
 
-def generator(tmp_dir, train, characters=False):
-  """Generator for lm1b sentences.
-
-  Args:
-    tmp_dir: a string.
-    train: a boolean.
-    characters: a boolean
-
-  Yields:
-    A dictionary {"inputs": [0], "targets": [<subword ids>]}
-  """
-  _maybe_download_corpus(tmp_dir)
-  original_vocab = _original_vocab(tmp_dir)
-  files = (_train_data_filenames(tmp_dir) if train
-           else [_dev_data_filename(tmp_dir)])
-  if characters:
-    encoder = text_encoder.ByteTextEncoder()
-  else:
-    encoder = _get_or_build_subword_text_encoder(tmp_dir)
-  for filepath in files:
-    tf.logging.info("filepath = %s", filepath)
-    for line in tf.gfile.Open(filepath):
-      tokens = encoder.encode(
-          _replace_oov(original_vocab, text_encoder.native_to_unicode(line)))
-      tokens.append(EOS)
-      yield {"inputs": [0], "targets": tokens}
+@registry.register_problem("languagemodel_1b32k")
+class LanguagemodelLm1b(problem.Text2TextProblem):
+  """A language model on the One Billion Word Benchmark (lm1b) corpus."""
+
+  @property
+  def is_character_level(self):
+    return False
+
+  @property
+  def has_inputs(self):
+    return True
+
+  @property
+  def input_space_id(self):
+    return problem.SpaceID.EN_TOK
+
+  @property
+  def target_space_id(self):
+    return problem.SpaceID.EN_TOK
+
+  @property
+  def num_shards(self):
+    return 10
+
+  @property
+  def vocab_name(self):
+    return "vocab-2016-09-10.txt.en"
+
+  @property
+  def use_subword_tokenizer(self):
+    return True
+
+  @property
+  def targeted_vocab_size(self):
+    return 2**15  # 32768
+
+  @property
+  def use_train_shards_for_dev(self):
+    return True
+
+  def generator(self, tmp_dir, train, characters=False):
+    """Generator for lm1b sentences.
+
+    Args:
+      tmp_dir: a string.
+      train: a boolean.
+      characters: a boolean.
+
+    Yields:
+      A dictionary {"inputs": [0], "targets": [<subword ids>]}
+    """
+    _maybe_download_corpus(tmp_dir)
+    original_vocab = _original_vocab(tmp_dir)
+    files = (_train_data_filenames(tmp_dir) if train
+             else [_dev_data_filename(tmp_dir)])
+    if characters:
+      encoder = text_encoder.ByteTextEncoder()
+    else:
+      encoder = _get_or_build_subword_text_encoder(tmp_dir)
+    for filepath in files:
+      tf.logging.info("filepath = %s", filepath)
+      for line in tf.gfile.Open(filepath):
+        tokens = encoder.encode(
+            _replace_oov(original_vocab, text_encoder.native_to_unicode(line)))
+        tokens.append(EOS)
+        yield {"inputs": [0], "targets": tokens}

tensor2tensor/data_generators/text_encoder.py

Lines changed: 19 additions & 6 deletions
@@ -647,19 +647,32 @@ def _init_alphabet_from_tokens(self, tokens):
     self._alphabet = {c for token in tokens for c in token}
     self._alphabet |= _ESCAPE_CHARS
 
-  def _load_from_file(self, filename):
-    """Load from a file.
+  def _load_from_file_object(self, f):
+    """Load from a file object.
 
     Args:
-      filename: filename to load vocabulary from
+      f: File object to load vocabulary from
     """
     subtoken_strings = []
-    with tf.gfile.Open(filename) as f:
-      for line in f:
-        subtoken_strings.append(native_to_unicode(line.strip()[1:-1]))
+    for line in f:
+      s = line.strip()
+      # Some vocab files wrap words in single quotes, but others don't
+      if (len(s) > 1 and ((s.startswith("'") and s.endswith("'")) or
+                          (s.startswith("\"") and s.endswith("\"")))):
+        s = s[1:-1]
+      subtoken_strings.append(native_to_unicode(s))
     self._init_subtokens_from_list(subtoken_strings)
     self._init_alphabet_from_tokens(subtoken_strings)
 
+  def _load_from_file(self, filename):
+    """Load from a file.
+
+    Args:
+      filename: Filename to load vocabulary from
+    """
+    with tf.gfile.Open(filename) as f:
+      self._load_from_file_object(f)
+
   def store_to_file(self, filename):
     with tf.gfile.Open(filename, "w") as f:
       for subtoken_string in self._all_subtoken_strings:
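
The effect of the refactor is easiest to see with an in-memory file: quoted and unquoted vocab lines now load to the same subtokens. A minimal sketch using the private _load_from_file_object from the diff above (a private helper, shown for illustration only):

import io

from tensor2tensor.data_generators import text_encoder

# Unquoted and single-quoted vocab files should now load identically.
for raw in ("the\nand\nof\n", "'the'\n'and'\n'of'\n"):
  enc = text_encoder.SubwordTextEncoder()
  enc._load_from_file_object(io.StringIO(raw))
  print(enc._all_subtoken_strings)  # ['the', 'and', 'of'] in both cases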

tensor2tensor/data_generators/text_encoder_test.py

Lines changed: 27 additions & 0 deletions
@@ -21,6 +21,7 @@
 from __future__ import unicode_literals
 
 import collections
+import io
 import os
 import shutil
 
@@ -31,6 +32,14 @@
 import tensorflow as tf
 
 
+class NativeToUnicodeTest(tf.test.TestCase):
+
+  def test_native_to_unicode(self):
+    s = r'foo bar'
+    self.assertIsInstance(text_encoder.native_to_unicode(s), unicode)
+    self.assertEqual(text_encoder.native_to_unicode(s), u'foo bar')
+
+
 class EscapeUnescapeTokenTest(tf.test.TestCase):
 
   def test_escape_token(self):
 
@@ -186,6 +195,24 @@ def test_raises_exception_when_not_encodable(self):
     with self.assertRaises(AssertionError):
       encoder.encode(original)
 
+  def test_load_from_file(self):
+    # Test a vocab file with words not wrapped with single quotes
+    encoder = text_encoder.SubwordTextEncoder()
+    correct_vocab = ['the', 'and', 'of']
+    vocab = io.StringIO('the\n'
+                        'and\n'
+                        'of\n')
+    encoder._load_from_file_object(vocab)
+    self.assertEqual(encoder._all_subtoken_strings, correct_vocab)
+
+    # Test a vocab file with words wrapped in single quotes
+    encoder = text_encoder.SubwordTextEncoder()
+    vocab = io.StringIO('\'the\'\n'
+                        '\'and\'\n'
+                        '\'of\'\n')
+    encoder._load_from_file_object(vocab)
+    self.assertEqual(encoder._all_subtoken_strings, correct_vocab)
+
 
 if __name__ == '__main__':
   tf.test.main()
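
One caveat: NativeToUnicodeTest asserts against the unicode builtin, which exists only under Python 2. A version-agnostic variant (my assumption, not part of this commit) could assert on six.text_type instead, since the codebase already depends on six:

import six

from tensor2tensor.data_generators import text_encoder

# six.text_type is unicode on Python 2 and str on Python 3, so this check
# passes under either interpreter.
assert isinstance(text_encoder.native_to_unicode("foo bar"), six.text_type)
assert text_encoder.native_to_unicode("foo bar") == u"foo bar"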
