Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 6eea0e2

Browse files
Lukasz Kaiser, Ryan Sepassi
authored and committed
Add an option to score files to t2t_decoder.
PiperOrigin-RevId: 191769234
1 parent b951c79 commit 6eea0e2

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

tensor2tensor/bin/t2t_decoder.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
# Dependency imports
3838

3939
from tensor2tensor.bin import t2t_trainer
40+
from tensor2tensor.data_generators import text_encoder
4041
from tensor2tensor.utils import decoding
42+
from tensor2tensor.utils import registry
4143
from tensor2tensor.utils import trainer_lib
4244
from tensor2tensor.utils import usr_dir
4345

@@ -59,6 +61,8 @@
5961
flags.DEFINE_bool("decode_interactive", False,
6062
"Interactive local inference mode.")
6163
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
64+
flags.DEFINE_string("score_file", "", "File to score. Each line in the file "
65+
"must be in the format input \t target.")
6266

6367

6468
def create_hparams():
@@ -96,12 +100,80 @@ def decode(estimator, hparams, decode_hp):
96100
dataset_split="test" if FLAGS.eval_use_test_set else None)
97101

98102

103+
def score_file(filename):
  """Score each line in a file and return the scores.

  Builds the model named by FLAGS.model in EVAL mode on placeholder-fed
  features, restores the checkpoint recorded for FLAGS.output_dir, and
  fetches losses["training"] once per line of the file.

  Args:
    filename: Path of the file to score. Each line is either `target` or
      `input<TAB>target`; an input part is required when the problem has an
      "inputs" feature encoder.

  Returns:
    A list with one entry per line of the file, in file order: the value of
    losses["training"] fetched for that line.

  Raises:
    ValueError: If no checkpoint is found in FLAGS.output_dir, if a line has
      more than one tab separator, or if a line has no input part although
      the problem has an "inputs" encoder.
  """
  # Prepare model.
  hparams = create_hparams()
  encoders = registry.problem(FLAGS.problems).feature_encoders(FLAGS.data_dir)
  has_inputs = "inputs" in encoders

  # Prepare features for feeding into the model. Each placeholder receives a
  # flat list of ids which is reshaped to 4D with a batch size of 1.
  features = {}
  if has_inputs:
    inputs_ph = tf.placeholder(dtype=tf.int32)  # Just length dimension.
    features["inputs"] = tf.reshape(inputs_ph, [1, -1, 1, 1])  # Make it 4D.
  targets_ph = tf.placeholder(dtype=tf.int32)  # Just length dimension.
  features["targets"] = tf.reshape(targets_ph, [1, -1, 1, 1])  # Make it 4D.

  # Prepare the model and the graph when model runs on features.
  model = registry.model(FLAGS.model)(hparams, tf.estimator.ModeKeys.EVAL)
  _, losses = model(features)
  saver = tf.train.Saver()

  # get_checkpoint_state returns None when the directory has no checkpoint;
  # fail with a clear message instead of an AttributeError on None.
  ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
  if ckpts is None or not ckpts.model_checkpoint_path:
    raise ValueError(
        "No checkpoint found in output_dir: %s" % FLAGS.output_dir)

  results = []
  with tf.Session() as sess:
    # Load weights from checkpoint.
    saver.restore(sess, ckpts.model_checkpoint_path)
    # Run on each line; the context manager closes the file when done.
    with open(filename) as score_lines:
      for line_number, line in enumerate(score_lines, 1):
        tab_split = line.split("\t")
        if len(tab_split) > 2:
          raise ValueError("Each line must have at most one tab separator.")
        # The target is always the last (or only) tab-separated field.
        targets = tab_split[-1].strip()
        # Run encoders and append EOS symbol.
        feed = {
            targets_ph: encoders["targets"].encode(targets) +
                        [text_encoder.EOS_ID]
        }
        if has_inputs:
          # Without this check a tab-less line would reuse the previous
          # line's input (or hit an unbound name on the first line).
          if len(tab_split) != 2:
            raise ValueError(
                "Line %d has no tab-separated input, but the problem "
                "requires inputs." % line_number)
          inputs = tab_split[0].strip()
          feed[inputs_ph] = (
              encoders["inputs"].encode(inputs) + [text_encoder.EOS_ID])
        # Get the score.
        results.append(sess.run(losses["training"], feed))
  return results
156+
157+
99158
def main(_):
100159
tf.logging.set_verbosity(tf.logging.INFO)
101160
trainer_lib.set_random_seed(FLAGS.random_seed)
102161
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
103162
FLAGS.use_tpu = False # decoding not supported on TPU
104163

164+
if FLAGS.score_file:
165+
filename = os.path.expanduser(FLAGS.score_file)
166+
if not tf.gfile.Exists(filename):
167+
raise ValueError("The file to score doesn't exist: %s" % filename)
168+
results = score_file(filename)
169+
if not FLAGS.decode_to_file:
170+
raise ValueError("To score a file, specify --decode_to_file for results.")
171+
write_file = open(os.path.expanduser(FLAGS.decode_to_file), "w")
172+
for score in results:
173+
write_file.write("%.6f\n" % score)
174+
write_file.close()
175+
return
176+
105177
hp = create_hparams()
106178
decode_hp = create_decode_hparams()
107179

0 commit comments

Comments
 (0)