Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit f715f85

Browse files
author
Ryan Sepassi
committed
Separate CLI t2t_decoder
PiperOrigin-RevId: 166920562
1 parent a3be70a commit f715f85

File tree

5 files changed

+123
-34
lines changed

5 files changed

+123
-34
lines changed

README.md

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,26 @@ You can chat with us and other users on
2929
with T2T announcements.
3030

3131
Here is a one-command version that installs tensor2tensor, downloads the data,
32-
trains an English-German translation model, and lets you use it interactively:
32+
trains an English-German translation model, and evaluates it:
3333
```
3434
pip install tensor2tensor && t2t-trainer \
3535
--generate_data \
3636
--data_dir=~/t2t_data \
3737
--problems=translate_ende_wmt32k \
3838
--model=transformer \
3939
--hparams_set=transformer_base_single_gpu \
40-
--output_dir=~/t2t_train/base \
40+
--output_dir=~/t2t_train/base
41+
```
42+
43+
You can decode from the model interactively:
44+
45+
```
46+
t2t-decoder \
47+
--data_dir=~/t2t_data \
48+
--problems=translate_ende_wmt32k \
49+
--model=transformer \
50+
--hparams_set=transformer_base_single_gpu \
51+
--output_dir=~/t2t_train/base
4152
--decode_interactive
4253
```
4354

@@ -106,14 +117,12 @@ echo "Goodbye world" >> $DECODE_FILE
106117
BEAM_SIZE=4
107118
ALPHA=0.6
108119
109-
t2t-trainer \
120+
t2t-decoder \
110121
--data_dir=$DATA_DIR \
111122
--problems=$PROBLEM \
112123
--model=$MODEL \
113124
--hparams_set=$HPARAMS \
114125
--output_dir=$TRAIN_DIR \
115-
--train_steps=0 \
116-
--eval_steps=0 \
117126
--decode_beam_size=$BEAM_SIZE \
118127
--decode_alpha=$ALPHA \
119128
--decode_from_file=$DECODE_FILE

docs/walkthrough.md

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,26 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
1010
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
1111

1212
Here is a one-command version that installs tensor2tensor, downloads the data,
13-
trains an English-German translation model, and lets you use it interactively:
13+
trains an English-German translation model, and evaluates it:
1414
```
1515
pip install tensor2tensor && t2t-trainer \
1616
--generate_data \
1717
--data_dir=~/t2t_data \
1818
--problems=translate_ende_wmt32k \
1919
--model=transformer \
2020
--hparams_set=transformer_base_single_gpu \
21-
--output_dir=~/t2t_train/base \
21+
--output_dir=~/t2t_train/base
22+
```
23+
24+
You can decode from the model interactively:
25+
26+
```
27+
t2t-decoder \
28+
--data_dir=~/t2t_data \
29+
--problems=translate_ende_wmt32k \
30+
--model=transformer \
31+
--hparams_set=transformer_base_single_gpu \
32+
--output_dir=~/t2t_train/base
2233
--decode_interactive
2334
```
2435

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
scripts=[
2323
'tensor2tensor/bin/t2t-trainer',
2424
'tensor2tensor/bin/t2t-datagen',
25+
'tensor2tensor/bin/t2t-decoder',
2526
'tensor2tensor/bin/t2t-make-tf-configs',
2627
],
2728
install_requires=[

tensor2tensor/bin/t2t-decoder

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#!/usr/bin/env python
2+
# coding=utf-8
3+
# Copyright 2017 The Tensor2Tensor Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
r"""Decode from trained T2T models.
18+
19+
This binary performs inference using the Estimator API.
20+
21+
Example usage to decode from dataset:
22+
23+
t2t-decoder \
24+
--data_dir ~/data \
25+
--problems=algorithmic_identity_binary40 \
26+
--model=transformer
27+
--hparams_set=transformer_base
28+
29+
Set FLAGS.decode_interactive or FLAGS.decode_from_file for alternative decode
30+
sources.
31+
"""
32+
from __future__ import absolute_import
33+
from __future__ import division
34+
from __future__ import print_function
35+
36+
import os
37+
38+
# Dependency imports
39+
40+
from tensor2tensor.utils import decoding
41+
from tensor2tensor.utils import trainer_utils
42+
from tensor2tensor.utils import usr_dir
43+
44+
import tensorflow as tf
45+
46+
flags = tf.flags
47+
FLAGS = flags.FLAGS
48+
49+
flags.DEFINE_string("t2t_usr_dir", "",
50+
"Path to a Python module that will be imported. The "
51+
"__init__.py file should include the necessary imports. "
52+
"The imported files should contain registrations, "
53+
"e.g. @registry.register_model calls, that will then be "
54+
"available to the t2t-decoder.")
55+
56+
57+
def main(_):
58+
tf.logging.set_verbosity(tf.logging.INFO)
59+
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
60+
trainer_utils.log_registry()
61+
trainer_utils.validate_flags()
62+
data_dir = os.path.expanduser(FLAGS.data_dir)
63+
output_dir = os.path.expanduser(FLAGS.output_dir)
64+
65+
hparams = trainer_utils.create_hparams(
66+
FLAGS.hparams_set, FLAGS.problems, data_dir, passed_hparams=FLAGS.hparams)
67+
estimator, _ = trainer_utils.create_experiment_components(
68+
hparams=hparams,
69+
output_dir=output_dir,
70+
data_dir=data_dir,
71+
model_name=FLAGS.model)
72+
73+
if FLAGS.decode_interactive:
74+
decoding.decode_interactively(estimator)
75+
elif FLAGS.decode_from_file:
76+
decoding.decode_from_file(estimator, FLAGS.decode_from_file)
77+
else:
78+
decoding.decode_from_dataset(
79+
estimator,
80+
FLAGS.problems.split("-"),
81+
return_beams=FLAGS.decode_return_beams,
82+
beam_size=FLAGS.decode_beam_size,
83+
max_predictions=FLAGS.decode_num_samples,
84+
decode_to_file=FLAGS.decode_to_file,
85+
save_images=FLAGS.decode_save_images,
86+
identity_output=FLAGS.identity_output)
87+
88+
89+
if __name__ == "__main__":
90+
tf.app.run()

tensor2tensor/utils/trainer_utils.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from tensor2tensor.data_generators import problem_hparams
2828
from tensor2tensor.models import models # pylint: disable=unused-import
2929
from tensor2tensor.utils import data_reader
30-
from tensor2tensor.utils import decoding
3130
from tensor2tensor.utils import devices
3231
from tensor2tensor.utils import input_fn_builder
3332
from tensor2tensor.utils import metrics
@@ -101,16 +100,13 @@
101100
flags.DEFINE_string("ps_job", "/job:ps", "name of ps job")
102101
flags.DEFINE_integer("ps_replicas", 0, "How many ps replicas.")
103102

104-
# Decode flags
105-
# Set one of {decode_from_dataset, decode_interactive, decode_from_file} to
106-
# decode.
107-
flags.DEFINE_bool("decode_from_dataset", False, "Decode from dataset on disk.")
108-
flags.DEFINE_bool("decode_use_last_position_only", False,
109-
"In inference, use last position only for speedup.")
103+
# Decoding flags
104+
flags.DEFINE_string("decode_from_file", None, "Path to decode file")
110105
flags.DEFINE_bool("decode_interactive", False,
111106
"Interactive local inference mode.")
107+
flags.DEFINE_bool("decode_use_last_position_only", False,
108+
"In inference, use last position only for speedup.")
112109
flags.DEFINE_bool("decode_save_images", False, "Save inference input images.")
113-
flags.DEFINE_string("decode_from_file", None, "Path to decode file")
114110
flags.DEFINE_string("decode_to_file", None, "Path to inference output file")
115111
flags.DEFINE_integer("decode_shards", 1, "How many shards to decode.")
116112
flags.DEFINE_integer("decode_problem_id", 0, "Which problem to decode.")
@@ -128,7 +124,7 @@
128124
"Maximum number of ids in input. Or <= 0 for no max.")
129125
flags.DEFINE_bool("identity_output", False, "To print the output as identity")
130126
flags.DEFINE_integer("decode_num_samples", -1,
131-
"Number of samples to decode. Currently used in"
127+
"Number of samples to decode. Currently used in "
132128
"decode_from_dataset. Use -1 for all.")
133129

134130

@@ -303,7 +299,6 @@ def run(data_dir, model, output_dir, train_steps, eval_steps, schedule):
303299
if exp.train_steps > 0 or exp.eval_steps > 0:
304300
tf.logging.info("Performing local training and evaluation.")
305301
exp.train_and_evaluate()
306-
decode(exp.estimator)
307302
else:
308303
# Perform distributed training/evaluation.
309304
learn_runner.run(
@@ -350,20 +345,3 @@ def session_config():
350345

351346
def get_data_filepatterns(data_dir, mode):
352347
return data_reader.get_data_filepatterns(FLAGS.problems, data_dir, mode)
353-
354-
355-
def decode(estimator):
356-
if FLAGS.decode_interactive:
357-
decoding.decode_interactively(estimator)
358-
elif FLAGS.decode_from_file is not None and FLAGS.decode_from_file is not "":
359-
decoding.decode_from_file(estimator, FLAGS.decode_from_file)
360-
elif FLAGS.decode_from_dataset:
361-
decoding.decode_from_dataset(
362-
estimator,
363-
FLAGS.problems.split("-"),
364-
return_beams=FLAGS.decode_return_beams,
365-
beam_size=FLAGS.decode_beam_size,
366-
max_predictions=FLAGS.decode_num_samples,
367-
decode_to_file=FLAGS.decode_to_file,
368-
save_images=FLAGS.decode_save_images,
369-
identity_output=FLAGS.identity_output)

0 commit comments

Comments
 (0)