This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8353ef2

Lukasz Kaiser authored and Ryan Sepassi committed
Finish LM1B transfer to Problem, add CNN+DailyMail dataset, style corrections.
PiperOrigin-RevId: 166918589
1 parent 5bf1e82 commit 8353ef2

File tree: 9 files changed (+206, -98 lines)


tensor2tensor/bin/t2t-datagen

Lines changed: 0 additions & 9 deletions
@@ -42,7 +42,6 @@ from tensor2tensor.data_generators import algorithmic_math
 from tensor2tensor.data_generators import all_problems  # pylint: disable=unused-import
 from tensor2tensor.data_generators import audio
 from tensor2tensor.data_generators import generator_utils
-from tensor2tensor.data_generators import lm1b
 from tensor2tensor.data_generators import snli
 from tensor2tensor.data_generators import wmt
 from tensor2tensor.data_generators import wsj_parsing
@@ -92,14 +91,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
             FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9),
         lambda: wsj_parsing.parsing_token_generator(
             FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9)),
-    "languagemodel_1b32k": (
-        lambda: lm1b.generator(FLAGS.tmp_dir, True),
-        lambda: lm1b.generator(FLAGS.tmp_dir, False)
-    ),
-    "languagemodel_1b_characters": (
-        lambda: lm1b.generator(FLAGS.tmp_dir, True, characters=True),
-        lambda: lm1b.generator(FLAGS.tmp_dir, False, characters=True)
-    ),
     "inference_snli32k": (
         lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
 from tensor2tensor.data_generators import algorithmic_math
 from tensor2tensor.data_generators import audio
 from tensor2tensor.data_generators import cipher
+from tensor2tensor.data_generators import cnn_dailymail
 from tensor2tensor.data_generators import desc2code
 from tensor2tensor.data_generators import ice_parsing
 from tensor2tensor.data_generators import image
tensor2tensor/data_generators/cnn_dailymail.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data generators for the CNN and Daily Mail datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tarfile
+
+# Dependency imports
+
+import six
+from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
+from tensor2tensor.data_generators import text_encoder
+from tensor2tensor.utils import registry
+
+import tensorflow as tf
+
+
+# Links to data from http://cs.nyu.edu/~kcho/DMQA/
+_CNN_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ"
+
+_DAILYMAIL_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs"
+
+
+# End-of-sentence marker.
+EOS = text_encoder.EOS_ID
+
+
+def _maybe_download_corpora(tmp_dir):
+  """Download corpora if necessary and unzip them.
+
+  Args:
+    tmp_dir: directory containing dataset.
+
+  Returns:
+    filepath of the downloaded corpus file.
+  """
+  cnn_filename = "cnn_stories.tgz"
+  dailymail_filename = "dailymail_stories.tgz"
+  cnn_finalpath = os.path.join(tmp_dir, "cnn/stories/")
+  dailymail_finalpath = os.path.join(tmp_dir, "dailymail/stories/")
+  if not tf.gfile.Exists(cnn_finalpath):
+    cnn_file = generator_utils.maybe_download_from_drive(
+        tmp_dir, cnn_filename, _CNN_STORIES_DRIVE_URL)
+    with tarfile.open(cnn_file, "r:gz") as cnn_tar:
+      cnn_tar.extractall(tmp_dir)
+  if not tf.gfile.Exists(dailymail_finalpath):
+    dailymail_file = generator_utils.maybe_download_from_drive(
+        tmp_dir, dailymail_filename, _CNN_STORIES_DRIVE_URL)
+    with tarfile.open(dailymail_file, "r:gz") as dailymail_tar:
+      dailymail_tar.extractall(tmp_dir)
+  return [cnn_finalpath, dailymail_finalpath]
+
+
+def story_generator(tmp_dir):
+  paths = _maybe_download_corpora(tmp_dir)
+  for path in paths:
+    for story_file in tf.gfile.Glob(path + "*"):
+      story = u""
+      for line in tf.gfile.Open(story_file):
+        line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8")
+        story += line
+      yield story
+
+
+def _story_summary_split(story):
+  end_pos = story.find("\n\n")  # Upto first empty line.
+  assert end_pos != -1
+  return story[:end_pos], story[end_pos:].strip()
+
+
+@registry.register_problem
+class SummarizeCnnDailymail32k(problem.Text2TextProblem):
+  """Summarize CNN and Daily Mail articles to their first paragraph."""
+
+  @property
+  def is_character_level(self):
+    return False
+
+  @property
+  def has_inputs(self):
+    return True
+
+  @property
+  def input_space_id(self):
+    return problem.SpaceID.EN_TOK
+
+  @property
+  def target_space_id(self):
+    return problem.SpaceID.EN_TOK
+
+  @property
+  def num_shards(self):
+    return 100
+
+  @property
+  def vocab_name(self):
+    return "vocab.cnndailymail"
+
+  @property
+  def use_subword_tokenizer(self):
+    return True
+
+  @property
+  def targeted_vocab_size(self):
+    return 2**15  # 32768
+
+  @property
+  def use_train_shards_for_dev(self):
+    return True
+
+  def generator(self, data_dir, tmp_dir, _):
+    encoder = generator_utils.get_or_generate_vocab_inner(
+        data_dir, self.vocab_file, self.targeted_vocab_size,
+        lambda: story_generator(tmp_dir))
+    for story in story_generator(tmp_dir):
+      summary, rest = _story_summary_split(story)
+      encoded_summary = encoder.encode(summary) + [EOS]
+      encoded_story = encoder.encode(rest) + [EOS]
+      yield {"inputs": encoded_story, "targets": encoded_summary}

tensor2tensor/data_generators/imdb.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def vocab_file(self):

   @property
   def targeted_vocab_size(self):
-    return 2**15
+    return 2**13  # 8k vocab suffices for this small dataset.

   def doc_generator(self, imdb_dir, dataset, include_label=False):
     dirs = [(os.path.join(imdb_dir, dataset, "pos"), True), (os.path.join(

tensor2tensor/data_generators/lm1b.py

Lines changed: 16 additions & 5 deletions
@@ -142,9 +142,9 @@ def _get_or_build_subword_text_encoder(tmp_dir):
   return ret


-@registry.register_problem("languagemodel_1b32k")
-class LanguagemodelLm1b(problem.Text2TextProblem):
-  """A language model on full English Wikipedia."""
+@registry.register_problem
+class LanguagemodelLm1b32k(problem.Text2TextProblem):
+  """A language model on the 1B words corpus."""

   @property
   def is_character_level(self):
@@ -156,6 +156,8 @@ def has_inputs(self):

   @property
   def input_space_id(self):
+    # Ratio of dev tokens (including eos) to dev words (including eos)
+    # 176884 / 159658 = 1.107893; multiply ppx by this to compare results.
     return problem.SpaceID.EN_TOK

   @property
@@ -164,11 +166,11 @@ def target_space_id(self):

   @property
   def num_shards(self):
-    return 10
+    return 100

   @property
   def vocab_name(self):
-    return "vocab-2016-09-10.txt.en"
+    return "vocab.lm1b.en"

   @property
   def use_subword_tokenizer(self):
@@ -208,3 +210,12 @@ def generator(self, tmp_dir, train, characters=False):
           _replace_oov(original_vocab, text_encoder.native_to_unicode(line)))
       tokens.append(EOS)
       yield {"inputs": [0], "targets": tokens}
+
+
+@registry.register_problem
+class LanguagemodelLm1bCharacters(LanguagemodelLm1b32k):
+  """A language model on the 1B words corpus, character level."""
+
+  @property
+  def is_character_level(self):
+    return True
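
The new comment above carries over the token-to-word ratio from the perplexity_exponent field this commit removes from problem_hparams.py (next diff). Since the total dev-set log-likelihood is the same whether it is normalized per subword token or per word, the ratio acts as an exponent on perplexity (equivalently, a multiplier on log-perplexity). A small arithmetic sketch; the per-token perplexity value is made up:

    # Converting per-subword-token perplexity to per-word perplexity on the
    # LM1B dev set, using the ratio quoted above (176884 tokens / 159658 words).
    ratio = 176884.0 / 159658.0      # ~1.107893
    per_token_ppl = 45.0             # hypothetical model result
    per_word_ppl = per_token_ppl ** ratio   # exponentiate, i.e. scale log-perplexity
    print("ratio=%.6f  per-word ppl=%.2f" % (ratio, per_word_ppl))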

tensor2tensor/data_generators/problem_hparams.py

Lines changed: 0 additions & 33 deletions
@@ -267,35 +267,6 @@ def audio_timit_tokens(model_hparams, wrong_vocab_size):
   return p


-def lm1b_32k(model_hparams):
-  """Billion-word language-modeling benchmark, 32k subword vocabulary."""
-  p = default_problem_hparams()
-  # ratio of dev tokens (including eos) to dev words (including eos)
-  # 176884 / 159658 = 1.107893
-  p.perplexity_exponent = 1.107893
-  p.input_modality = {}
-  encoder = text_encoder.SubwordTextEncoder(
-      os.path.join(model_hparams.data_dir, "lm1b_32k.subword_text_encoder"))
-  p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
-  p.vocabulary = {"targets": encoder}
-  p.target_space_id = 3
-  return p
-
-
-def lm1b_characters(unused_model_hparams):
-  """Billion-word language-modeling benchmark, 32k subword vocabulary."""
-  p = default_problem_hparams()
-  # ratio of dev tokens (including eos) to dev words (including eos)
-  # 826189 / 159658 = 5.174742
-  p.perplexity_exponent = 5.174742
-  p.input_modality = {}
-  encoder = text_encoder.ByteTextEncoder()
-  p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
-  p.vocabulary = {"targets": encoder}
-  p.target_space_id = 2
-  return p
-
-
 def wmt_parsing_characters(model_hparams):
   """English to parse tree translation benchmark."""
   del model_hparams  # Unused.
@@ -404,10 +375,6 @@ def img2img_imagenet(unused_model_hparams):
         lambda p: audio_timit_tokens(p, 2**13),
     "audio_timit_tokens_8k_test":
         lambda p: audio_timit_tokens(p, 2**13),
-    "languagemodel_1b_characters":
-        lm1b_characters,
-    "languagemodel_1b32k":
-        lm1b_32k,
     "parsing_english_ptb8k":
         lambda p: wmt_parsing_tokens(p, 2**13),
     "parsing_english_ptb16k":

tensor2tensor/data_generators/text_encoder.py

Lines changed: 2 additions & 2 deletions
@@ -657,8 +657,8 @@ def _load_from_file_object(self, f):
     for line in f:
       s = line.strip()
       # Some vocab files wrap words in single quotes, but others don't
-      if (len(s) > 1 and ((s.startswith("'") and s.endswith("'")) or
-                          (s.startswith("\"") and s.endswith("\"")))):
+      if ((s.startswith("'") and s.endswith("'")) or
+          (s.startswith("\"") and s.endswith("\""))):
         s = s[1:-1]
       subtoken_strings.append(native_to_unicode(s))
     self._init_subtokens_from_list(subtoken_strings)
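
The new condition drops the len(s) > 1 guard but keeps the unwrapping of vocab-file lines enclosed in matching single or double quotes. A minimal, self-contained restatement of that logic on sample lines; the helper name strip_wrapping_quotes is mine, for illustration only:

    def strip_wrapping_quotes(s):
      # Mirrors the new condition above: unwrap matching single or double quotes.
      if ((s.startswith("'") and s.endswith("'")) or
          (s.startswith("\"") and s.endswith("\""))):
        return s[1:-1]
      return s

    assert strip_wrapping_quotes("'the_'") == "the_"
    assert strip_wrapping_quotes("\"and_\"") == "and_"
    assert strip_wrapping_quotes("plain_token") == "plain_token"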
