18 | 18 | from __future__ import division
19 | 19 | from __future__ import print_function
20 | 20 |
| 21 | +import os
| 22 | +
21 | 23 | # Dependency imports
22 | 24 |
23 | | -from tensor2tensor.data_generators import generator_utils as utils
| 25 | +from tensor2tensor.data_generators import generator_utils
24 | 26 | from tensor2tensor.data_generators import text_encoder
25 | 27 | from tensor2tensor.utils import metrics
| 28 | +from tensor2tensor.utils import registry
26 | 29 |
27 | 30 | import tensorflow as tf
28 | 31 |
@@ -176,20 +179,23 @@ def eval_metrics(self):
176 | 179 |   def training_filepaths(self, data_dir, num_shards, shuffled):
177 | 180 |     file_basename = self.dataset_filename()
178 | 181 |     if not shuffled:
179 | | -      file_basename += utils.UNSHUFFLED_SUFFIX
180 | | -    return utils.train_data_filenames(file_basename, data_dir, num_shards)
| 182 | +      file_basename += generator_utils.UNSHUFFLED_SUFFIX
| 183 | +    return generator_utils.train_data_filenames(
| 184 | +        file_basename, data_dir, num_shards)
181 | 185 |
182 | 186 |   def dev_filepaths(self, data_dir, num_shards, shuffled):
183 | 187 |     file_basename = self.dataset_filename()
184 | 188 |     if not shuffled:
185 | | -      file_basename += utils.UNSHUFFLED_SUFFIX
186 | | -    return utils.dev_data_filenames(file_basename, data_dir, num_shards)
| 189 | +      file_basename += generator_utils.UNSHUFFLED_SUFFIX
| 190 | +    return generator_utils.dev_data_filenames(
| 191 | +        file_basename, data_dir, num_shards)
187 | 192 |
188 | 193 |   def test_filepaths(self, data_dir, num_shards, shuffled):
189 | 194 |     file_basename = self.dataset_filename()
190 | 195 |     if not shuffled:
191 | | -      file_basename += utils.UNSHUFFLED_SUFFIX
192 | | -    return utils.test_data_filenames(file_basename, data_dir, num_shards)
| 196 | +      file_basename += generator_utils.UNSHUFFLED_SUFFIX
| 197 | +    return generator_utils.test_data_filenames(
| 198 | +        file_basename, data_dir, num_shards)
193 | 199 |
194 | 200 |   def __init__(self, was_reversed=False, was_copy=False):
195 | 201 |     """Create a Problem.
@@ -323,3 +329,97 @@ def _default_hparams():
323 | 329 |       # class.
324 | 330 |       input_space_id=SpaceID.GENERIC,
325 | 331 |       target_space_id=SpaceID.GENERIC)
| 332 | +
| 333 | +
| 334 | +class Text2TextProblem(Problem):
| 335 | +  """Base class for text-to-text problems."""
| 336 | +
| 337 | +  @property
| 338 | +  def is_character_level(self):
| 339 | +    raise NotImplementedError()
| 340 | +
| 341 | +  @property
| 342 | +  def targeted_vocab_size(self):
| 343 | +    raise NotImplementedError()  # Not needed if self.is_character_level.
| 344 | +
| 345 | +  def train_generator(self, data_dir, tmp_dir, is_training):
| 346 | +    """Generator of the training data."""
| 347 | +    raise NotImplementedError()
| 348 | +
| 349 | +  def dev_generator(self, data_dir, tmp_dir):
| 350 | +    """Generator of the development data."""
| 351 | +    return self.train_generator(data_dir, tmp_dir, False)
| 352 | +
| 353 | +  @property
| 354 | +  def input_space_id(self):
| 355 | +    raise NotImplementedError()
| 356 | +
| 357 | +  @property
| 358 | +  def target_space_id(self):
| 359 | +    raise NotImplementedError()
| 360 | +
| 361 | +  @property
| 362 | +  def num_shards(self):
| 363 | +    raise NotImplementedError()
| 364 | +
| 365 | +  @property
| 366 | +  def vocab_name(self):
| 367 | +    raise NotImplementedError()
| 368 | +
| 369 | +  @property
| 370 | +  def vocab_file(self):
| 371 | +    return "%s.%d" % (self.vocab_name, self.targeted_vocab_size)
| 372 | +
| 373 | +  @property
| 374 | +  def use_subword_tokenizer(self):
| 375 | +    raise NotImplementedError()
| 376 | +
| 377 | +  @property
| 378 | +  def has_inputs(self):
| 379 | +    return True  # Set to False for language models.
| 380 | +
| 381 | +  def generate_data(self, data_dir, tmp_dir, task_id=-1):
| 382 | +    generator_utils.generate_dataset_and_shuffle(
| 383 | +        self.train_generator(data_dir, tmp_dir, True),
| 384 | +        self.training_filepaths(data_dir, self.num_shards, shuffled=False),
| 385 | +        self.dev_generator(data_dir, tmp_dir),
| 386 | +        self.dev_filepaths(data_dir, 1, shuffled=False))
| 387 | +
| 388 | +  def feature_encoders(self, data_dir):
| 389 | +    vocab_filename = os.path.join(data_dir, self.vocab_file)
| 390 | +    if self.is_character_level:
| 391 | +      encoder = text_encoder.ByteTextEncoder()
| 392 | +    elif self.use_subword_tokenizer:
| 393 | +      encoder = text_encoder.SubwordTextEncoder(vocab_filename)
| 394 | +    else:
| 395 | +      encoder = text_encoder.TokenTextEncoder(vocab_filename)
| 396 | +    if self.has_inputs:
| 397 | +      return {"inputs": encoder, "targets": encoder}
| 398 | +    return {"targets": encoder}
| 399 | +
| 400 | +  def hparams(self, defaults, unused_model_hparams):
| 401 | +    p = defaults
| 402 | +    if self.is_character_level:
| 403 | +      source_vocab_size = 256
| 404 | +      target_vocab_size = 256
| 405 | +    else:
| 406 | +      target_vocab_size = self._encoders["targets"].vocab_size
| 407 | +      if self.has_inputs:
| 408 | +        source_vocab_size = self._encoders["inputs"].vocab_size
| 409 | +
| 410 | +    if self.has_inputs:
| 411 | +      p.input_modality = {"inputs": (registry.Modalities.SYMBOL,
| 412 | +                                     source_vocab_size)}
| 413 | +    p.target_modality = (registry.Modalities.SYMBOL, target_vocab_size)
| 414 | +    if self.has_inputs:
| 415 | +      p.input_space_id = self.input_space_id
| 416 | +    p.target_space_id = self.target_space_id
| 417 | +    if self.is_character_level:
| 418 | +      p.loss_multiplier = 2.0
| 419 | +
| 420 | +  def eval_metrics(self):
| 421 | +    return [
| 422 | +        metrics.Metrics.ACC, metrics.Metrics.ACC_TOP5,
| 423 | +        metrics.Metrics.ACC_PER_SEQ, metrics.Metrics.NEG_LOG_PERPLEXITY,
| 424 | +        metrics.Metrics.APPROX_BLEU
| 425 | +    ]
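
For orientation, here is a minimal sketch (not part of this commit) of how a subclass might fill in the abstract pieces of the new `Text2TextProblem` base class. The class name, vocab name, and the toy byte-copy generator are made up for illustration; only the overridden property and method names come from the class added above.

```python
# Illustrative only: a hypothetical byte-level subclass of Text2TextProblem.
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder


class MyByteCopyProblem(problem.Text2TextProblem):
  """Hypothetical byte-level problem that copies its input."""

  @property
  def is_character_level(self):
    return True  # feature_encoders() will build a ByteTextEncoder.

  @property
  def use_subword_tokenizer(self):
    return False

  @property
  def num_shards(self):
    return 1

  @property
  def vocab_name(self):
    # feature_encoders() builds a vocab filename even in the byte case, so a
    # name is still required; the file itself is never read for bytes.
    return "vocab.bytes"

  @property
  def targeted_vocab_size(self):
    return 256  # Only used to build the vocab filename here.

  @property
  def input_space_id(self):
    return problem.SpaceID.GENERIC

  @property
  def target_space_id(self):
    return problem.SpaceID.GENERIC

  def train_generator(self, data_dir, tmp_dir, is_training):
    # A real problem would read corpora from tmp_dir; this sketch yields one
    # hard-coded pair, encoded with the same byte-level encoder that
    # feature_encoders() builds.
    del data_dir, tmp_dir, is_training  # Unused in this sketch.
    encoder = text_encoder.ByteTextEncoder()
    yield {"inputs": encoder.encode("hello"),
           "targets": encoder.encode("hello")}
```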
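With overrides like these in place, calling `generate_data(data_dir, tmp_dir)` runs the base-class method added above: it writes the unshuffled shards returned by `training_filepaths` and `dev_filepaths`, then shuffles them via `generator_utils.generate_dataset_and_shuffle`. In practice such a subclass would typically also be registered (e.g. with `registry.register_problem`) so the trainer can look it up by name, though registration is outside the scope of this diff.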