
Commit a8ee62a

Author: Ryan Sepassi
Add IMDB sentiment classification dataset
PiperOrigin-RevId: 166905238
Parent: a2cf057

6 files changed: 150 additions, 25 deletions

tensor2tensor/data_generators/all_problems.py
Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@
 from tensor2tensor.data_generators import desc2code
 from tensor2tensor.data_generators import ice_parsing
 from tensor2tensor.data_generators import image
+from tensor2tensor.data_generators import imdb
 from tensor2tensor.data_generators import lm1b
 from tensor2tensor.data_generators import ptb
 from tensor2tensor.data_generators import snli

tensor2tensor/data_generators/image.py
Lines changed: 2 additions & 1 deletion

@@ -272,7 +272,8 @@ def hparams(self, defaults, model_hparams):
     small_modality = "%s:small_image_modality" % registry.Modalities.IMAGE
     modality = small_modality if self.is_small else registry.Modalities.IMAGE
     p.input_modality = {"inputs": (modality, None)}
-    p.target_modality = (registry.Modalities.CLASS_LABEL, self.num_classes)
+    p.target_modality = ("%s:2d" % registry.Modalities.CLASS_LABEL,
+                         self.num_classes)
     p.batch_size_multiplier = 4 if self.is_small else 256
     p.max_expected_batch_size_per_shard = 8 if self.is_small else 2
     p.loss_multiplier = 3.0 if self.is_small else 1.0

tensor2tensor/data_generators/imdb.py
Lines changed: 124 additions & 0 deletions (new file)

@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""IMDB Sentiment Classification Problem."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tarfile
+
+# Dependency imports
+
+from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
+from tensor2tensor.data_generators import text_encoder
+from tensor2tensor.utils import registry
+
+import tensorflow as tf
+
+# End-of-sentence marker.
+EOS = text_encoder.EOS_ID
+
+
+@registry.register_problem
+class SentimentIMDB(problem.Problem):
+  """IMDB sentiment classification."""
+  URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
+
+  @property
+  def num_shards(self):
+    return 10
+
+  @property
+  def vocab_file(self):
+    return "sentiment_imdb.vocab"
+
+  @property
+  def targeted_vocab_size(self):
+    return 2**15
+
+  def doc_generator(self, imdb_dir, dataset, include_label=False):
+    dirs = [(os.path.join(imdb_dir, dataset, "pos"), True), (os.path.join(
+        imdb_dir, dataset, "neg"), False)]
+
+    for d, label in dirs:
+      for filename in os.listdir(d):
+        with tf.gfile.Open(os.path.join(d, filename)) as imdb_f:
+          doc = imdb_f.read().strip()
+          if include_label:
+            yield doc, label
+          else:
+            yield doc
+
+  def generator(self, data_dir, tmp_dir, train):
+    """Generate examples."""
+    # Download and extract
+    compressed_filename = os.path.basename(self.URL)
+    download_path = generator_utils.maybe_download(tmp_dir, compressed_filename,
+                                                   self.URL)
+    imdb_dir = os.path.join(tmp_dir, "aclImdb")
+    if not tf.gfile.Exists(imdb_dir):
+      with tarfile.open(download_path, "r:gz") as tar:
+        tar.extractall(tmp_dir)
+
+    # Generate vocab
+    encoder = generator_utils.get_or_generate_vocab_inner(
+        data_dir, self.vocab_file, self.targeted_vocab_size,
+        lambda: self.doc_generator(imdb_dir, "train"))
+
+    # Generate examples
+    dataset = "train" if train else "test"
+    for doc, label in self.doc_generator(imdb_dir, dataset, include_label=True):
+      yield {
+          "inputs": encoder.encode(doc) + [EOS],
+          "targets": [int(label)],
+      }
+
+  def generate_data(self, data_dir, tmp_dir, task_id=-1):
+    train_paths = self.training_filepaths(
+        data_dir, self.num_shards, shuffled=False)
+    dev_paths = self.dev_filepaths(data_dir, 1, shuffled=False)
+    generator_utils.generate_dataset_and_shuffle(
+        self.generator(data_dir, tmp_dir, True), train_paths,
+        self.generator(data_dir, tmp_dir, False), dev_paths)
+
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    source_vocab_size = self._encoders["inputs"].vocab_size
+    p.input_modality = {
+        "inputs": (registry.Modalities.SYMBOL, source_vocab_size)
+    }
+    p.target_modality = (registry.Modalities.CLASS_LABEL, 2)
+    p.input_space_id = problem.SpaceID.EN_TOK
+    p.target_space_id = problem.SpaceID.GENERIC
+
+  def feature_encoders(self, data_dir):
+    vocab_filename = os.path.join(data_dir, self.vocab_file)
+    encoder = text_encoder.SubwordTextEncoder(vocab_filename)
+    return {
+        "inputs": encoder,
+        "targets": text_encoder.TextEncoder(),
+    }
+
+  def example_reading_spec(self):
+    data_fields = {
+        "inputs": tf.VarLenFeature(tf.int64),
+        "targets": tf.FixedLenFeature([1], tf.int64),
+    }
+    data_items_to_decoders = None
+    return (data_fields, data_items_to_decoders)
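
For orientation, a minimal usage sketch of the new problem follows. The registered name "sentiment_imdb" (the registry derives it from the SentimentIMDB class name) and the directory paths are assumptions for illustration, not part of this commit:

    # Hypothetical usage sketch; the problem name and paths are assumptions.
    from tensor2tensor.data_generators import all_problems  # triggers problem registration
    from tensor2tensor.utils import registry

    imdb_problem = registry.problem("sentiment_imdb")
    # Downloads and extracts aclImdb_v1.tar.gz, builds the subword vocab, and
    # writes sharded train/dev TFRecord files under data_dir.
    imdb_problem.generate_data("/tmp/t2t_data", "/tmp/t2t_tmp")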

tensor2tensor/data_generators/problem_hparams.py
Lines changed: 2 additions & 2 deletions

@@ -147,8 +147,8 @@ def default_problem_hparams():
       # Modalities used to map from input features to a space compatible with
       # chosen model architecture. One modality spec (which is a 2-tuple,
       # (modality_full_name, vocab_size)) per feature key. modality_full_name is
-      # a string type:name, e.g. class_label:class_label_2d. Leaving off the
-      # name uses the default modality for that type (e.g. class_label ==
+      # a string type:name, e.g. class_label:2d. Leaving off the name uses the
+      # default modality for that type (e.g. class_label ==
       # class_label:default).
       input_modality={},
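
To make the renamed convention concrete, here is a small illustrative sketch of modality specs (the vocab sizes are placeholders, not values from this commit):

    # A modality spec is a ("type:name", vocab_size) 2-tuple.
    input_modality = {"inputs": ("symbol", 2**15)}
    target_modality = ("class_label:2d", 2)  # the "2d" class-label variant
    # Leaving off ":name" selects the default for the type:
    # "class_label" is equivalent to "class_label:default".
    default_target_modality = ("class_label", 2)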

tensor2tensor/layers/modalities.py
Lines changed: 14 additions & 11 deletions

@@ -361,9 +361,9 @@ def xnet_resblock(x, filters, res_relu, name):
                              "compress_block_final")


-@registry.register_class_label_modality("default")
+@registry.register_class_label_modality("2d")
 class ClassLabelModality(modality.Modality):
-  """Used for label data."""
+  """Used for label data; if is2d=True, uses Xception flow to logits."""

   def __init__(self, model_hparams, vocab_size, is2d=True):
     super(ClassLabelModality, self).__init__(model_hparams, vocab_size)

@@ -397,9 +397,11 @@ def targets_bottom(self, x):
   def top(self, body_output, _):
     """Transform inputs from model space to target space.

-    Perform the Xception "Exit flow", consisting of a single residual block and
-    two separable convolutional upscalings followed by global spatial average
-    pooling.
+    If instantiated with is2d=True, perform the Xception "Exit flow", consisting
+    of a single residual block and two separable convolutional upscalings
+    followed by global spatial average pooling.
+
+    Otherwise, a single linear layer to logits.

     Args:
       body_output: A Tensor with shape [batch, ?, ?, body_output_size].

@@ -417,11 +419,12 @@ def top(self, body_output, _):
         spatial_dim = tf.to_int32(spatial_dim_float)
         x_depth = int(x.get_shape()[3])
         x = tf.reshape(x, [-1, spatial_dim, spatial_dim, x_depth])
-      x = common_layers.conv_block_downsample(x, self._kernel, self._strides,
-                                              self._padding)
-      x = tf.nn.relu(x)
-      x = tf.reduce_mean(x, axis=[1, 2], keep_dims=True)
-      res = common_layers.conv(x, self._vocab_size, (1, 1))
+        x = common_layers.conv_block_downsample(x, self._kernel, self._strides,
+                                                self._padding)
+        x = tf.nn.relu(x)
+        x = tf.reduce_mean(x, axis=[1, 2], keep_dims=True)
+
+      res = tf.layers.dense(x, self._vocab_size)
       return tf.expand_dims(res, 3)

   def loss(self, top_out, targets, weights_fn=common_layers.weights_all):

@@ -431,7 +434,7 @@ def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
         top_out, targets, weights_fn=weights_fn)


-@registry.register_class_label_modality("class_label_2d")
+@registry.register_class_label_modality("default")
 class ClassLabel1DModality(ClassLabelModality):
   """Used for label data."""
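
Since the non-2d path now ends in a plain linear projection, a standalone sketch of that computation may help (the function name and shape convention are assumptions for illustration):

    import tensorflow as tf

    def class_label_top_1d(body_output, vocab_size):
      # Hypothetical sketch of the non-2d top: a single linear layer to
      # logits, as the updated docstring describes.
      # body_output: [batch, length, 1, depth] -> [batch, length, 1, 1, vocab_size]
      res = tf.layers.dense(body_output, vocab_size)
      return tf.expand_dims(res, 3)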

tensor2tensor/utils/model_builder.py
Lines changed: 7 additions & 11 deletions

@@ -164,12 +164,6 @@ def model_fn(features, targets, mode):
       features = _interactive_input_tensor_to_features_dict(features, my_hp)
     elif FLAGS.decode_from_file:
       features = _decode_input_tensor_to_features_dict(features, my_hp)
-    # A dictionary containing:
-    # - problem_choice: A Tensor containing an integer indicating which problem
-    #   was selected for this run.
-    # - predictions: A Tensor containing the model's output predictions.
-    run_info = dict()
-    run_info["problem_choice"] = features["problem_choice"]

     if targets is not None:
       features["targets"] = targets

@@ -299,11 +293,13 @@ def nth_model(n):

     sharded_logits, total_loss = result_list[1:], result_list[0]
     if mode == tf.contrib.learn.ModeKeys.EVAL:
-      logits = tf.concat(sharded_logits, 0)
       # For evaluation, return the logits layer as our predictions.
-      run_info["predictions"] = logits
-      train_op = None
-      return run_info, total_loss, None
+      logits = tf.concat(sharded_logits, 0)
+      ret = {
+          "predictions": logits,
+          "problem_choice": features["problem_choice"],
+      }
+      return ret, total_loss, None

     assert mode == tf.contrib.learn.ModeKeys.TRAIN

@@ -385,7 +381,7 @@ def nth_model(n):
       del summaries[i]

     tf.logging.info("Global model_fn finished.")
-    return run_info, total_loss, train_op
+    return {"problem_choice": features["problem_choice"]}, total_loss, train_op

   return model_fn
