@@ -77,6 +77,41 @@ def _get_token_encoder(vocab_dir, vocab_name, filename):
   return text_encoder.TokenTextEncoder(vocab_path)


+def _maybe_download_corpus(tmp_dir, vocab_type):
+  """Download and unpack the corpus.
+
+  Args:
+    tmp_dir: directory containing dataset.
+    vocab_type: which vocabulary are we using.
+
+  Returns:
+    The list of names of files.
+  """
+  filename = os.path.basename(PTB_URL)
+  compressed_filepath = generator_utils.maybe_download(
+      tmp_dir, filename, PTB_URL)
+  ptb_files = []
+  ptb_char_files = []
+
+  with tarfile.open(compressed_filepath, "r:gz") as tgz:
+    files = []
+    # Selecting only relevant files.
+    for m in tgz.getmembers():
+      if "ptb" in m.name and ".txt" in m.name:
+        if "char" in m.name:
+          ptb_char_files += [m.name]
+        else:
+          ptb_files += [m.name]
+        files += [m]
+
+    tgz.extractall(tmp_dir, members=files)
+
+  if vocab_type == text_problems.VocabType.CHARACTER:
+    return ptb_char_files
+  else:
+    return ptb_files
+
+
 @registry.register_problem
 class LanguagemodelPtb10k(text_problems.Text2SelfProblem):
   """PTB, 10k vocab."""
@@ -91,6 +126,10 @@ def dataset_splits(self):
         "shards": 1,
     }]

+  @property
+  def is_generate_per_split(self):
+    return True
+
   @property
   def vocab_filename(self):
     return "vocab.lmptb.10000"
@@ -100,28 +139,7 @@ def vocab_type(self):
     return text_problems.VocabType.TOKEN

   def generate_samples(self, data_dir, tmp_dir, dataset_split):
-    filename = os.path.basename(PTB_URL)
-    compressed_filepath = generator_utils.maybe_download(
-        tmp_dir, filename, PTB_URL)
-    ptb_files = []
-    ptb_char_files = []
-    with tarfile.open(compressed_filepath, "r:gz") as tgz:
-      files = []
-      # Selecting only relevant files.
-      for m in tgz.getmembers():
-        if "ptb" in m.name and ".txt" in m.name:
-          if "char" in m.name:
-            ptb_char_files += [m.name]
-          else:
-            ptb_files += [m.name]
-          files += [m]
-
-      tgz.extractall(tmp_dir, members=files)
-
-    if self.vocab_type == text_problems.VocabType.CHARACTER:
-      files = ptb_char_files
-    else:
-      files = ptb_files
+    files = _maybe_download_corpus(tmp_dir, self.vocab_type)

     train_file, valid_file = None, None
     for filename in files:
@@ -138,10 +156,13 @@ def generate_samples(self, data_dir, tmp_dir, dataset_split):
     train = dataset_split == problem.DatasetSplit.TRAIN
     filepath = train_file if train else valid_file

-    with tf.gfile.GFile(filepath, "r") as f:
-      for line in f:
-        line = " ".join(line.replace("\n", " %s " % EOS).split())
-        yield {"targets": line}
+    def _generate_samples():
+      with tf.gfile.GFile(filepath, "r") as f:
+        for line in f:
+          line = " ".join(line.replace("\n", " %s " % EOS).split())
+          yield {"targets": line}
+
+    return _generate_samples()


 @registry.register_problem
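Design note on the last hunk: the old generate_samples contained a yield, so it was itself a generator function and none of its body (the download, the train/valid file checks) ran until the first sample was pulled. Wrapping the yields in an inner _generate_samples() makes that setup run eagerly when the method is called, while only the file streaming stays lazy. A minimal self-contained sketch of the distinction, with made-up names:

def eager_setup_then_stream(filepath):
  print("setup runs at call time")  # executes as soon as the function is called
  def _stream():
    with open(filepath) as f:       # executes only once iteration starts
      for line in f:
        yield line
  return _stream()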