This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit b6a57e7

martinpopel authored and rsepassi committed
t2t-decoder with --checkpoint_path (#524)
* No pip download progress bars in the Travis log; see #523.
* Allow specifying --checkpoint_path with t2t-decoder, and allow keeping the timestamp in that case. This is needed for t2t-translate-all + t2t-bleu to work as expected (I forgot to add this commit to #488).
* Prevent tf.gfile.Glob crashes due to concurrent filesystem edits. tf.gfile.Glob may crash with tensorflow.python.framework.errors_impl.NotFoundError: xy/model.ckpt-1130761_temp_9cb4cb0b0f5f4382b5ea947aadfb7a40; No such file or directory. Let's use standard glob.glob instead; it seems to be more reliable.
* Reintroduce FLAGS deleted by someone; this is needed for **locals() to work.
* Speed up BLEU tokenization. As I think about it, I would prefer my original implementation (https://github.com/tensorflow/tensor2tensor/blob/bb1173adce940e62c840970fa0f06f69fd9398db/tensor2tensor/utils/bleu_hook.py#L147-L157), but it seems there are some T2T/Google internal Python guidelines forbidding this, so we have to live with the singleton.
* Another solution of #523.
* Make save_checkpoints_secs work again. The functionality was broken during the adoption of the TPU trainer_lib.py instead of the original trainer_utils.py. Currently, the default is to save checkpoints every 2000 steps, while in previous T2T versions the default was every 10 minutes. (A sketch follows after this list.)
* Adapt according to @rsepassi's review.
* Update NotFoundError.
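The save_checkpoints_secs point is the one change not visible in the five diffs below (it lives in the trainer setup). As a hedged sketch only, assuming the TF 1.x tf.estimator.RunConfig API, with a made-up helper name and illustrative values rather than the commit's actual code: RunConfig takes either a time-based or a step-based checkpoint interval, not both.

import tensorflow as tf

def make_run_config(save_checkpoints_secs=None, save_checkpoints_steps=2000):
  # RunConfig accepts only one of the two intervals: a time-based value
  # restores the older "save every N minutes" behaviour, a step-based value
  # keeps the newer "save every N steps" behaviour.
  if save_checkpoints_secs:
    return tf.estimator.RunConfig(save_checkpoints_secs=save_checkpoints_secs)
  return tf.estimator.RunConfig(save_checkpoints_steps=save_checkpoints_steps)

# e.g. checkpoint every 10 minutes, as older T2T versions did by default:
config = make_run_config(save_checkpoints_secs=600)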
1 parent 02b903b commit b6a57e7

File tree: 5 files changed (+33, -8 lines)


.travis.yml
Lines changed: 2 additions & 2 deletions

@@ -6,8 +6,8 @@ before_install:
   - sudo apt-get update -qq
   - sudo apt-get install -qq libhdf5-dev
 install:
-  - pip install tensorflow
-  - pip install .[tests]
+  - pip install -q tensorflow
+  - pip install -q .[tests]
 env:
   global:
     - T2T_PROBLEM=algorithmic_reverse_binary40_test

tensor2tensor/bin/t2t_decoder.py
Lines changed: 9 additions & 1 deletion

@@ -47,10 +47,14 @@
 FLAGS = flags.FLAGS

 # Additional flags in bin/t2t_trainer.py and utils/flags.py
+flags.DEFINE_string("checkpoint_path", None,
+                    "Path to the model checkpoint. Overrides output_dir.")
 flags.DEFINE_string("decode_from_file", None,
                     "Path to the source file for decoding")
 flags.DEFINE_string("decode_to_file", None,
                     "Path to the decoded (output) file")
+flags.DEFINE_bool("keep_timestamp", True,
+                  "Set the mtime of the decoded file to the checkpoint_path+'.index' mtime.")
 flags.DEFINE_bool("decode_interactive", False,
                   "Interactive local inference mode.")
 flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
@@ -76,7 +80,11 @@ def decode(estimator, hparams, decode_hp):
     decoding.decode_interactively(estimator, hparams, decode_hp)
   elif FLAGS.decode_from_file:
     decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
-                              decode_hp, FLAGS.decode_to_file)
+                              decode_hp, FLAGS.decode_to_file,
+                              checkpoint_path=FLAGS.checkpoint_path)
+    if FLAGS.checkpoint_path and FLAGS.keep_timestamp:
+      ckpt_time = os.path.getmtime(FLAGS.checkpoint_path + '.index')
+      os.utime(FLAGS.decode_to_file, (ckpt_time, ckpt_time))
   else:
     decoding.decode_from_dataset(
         estimator,
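A small standalone sketch of the keep_timestamp behaviour added above (the paths in the usage comment are hypothetical): the decoded file's mtime is copied from the checkpoint's .index file, which the commit message says is needed for t2t-translate-all + t2t-bleu to work as expected.

import os

def copy_checkpoint_mtime(checkpoint_path, decoded_file):
  # Stamp the decoded output with the checkpoint's .index mtime instead of
  # the time the decode finished, mirroring the diff above.
  ckpt_time = os.path.getmtime(checkpoint_path + ".index")
  os.utime(decoded_file, (ckpt_time, ckpt_time))

# Hypothetical paths, for illustration only:
# copy_checkpoint_mtime("train_dir/model.ckpt-250000", "translations/out-250000")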

tensor2tensor/bin/t2t_translate_all.py
Lines changed: 3 additions & 1 deletion

@@ -81,10 +81,12 @@ def main(_):
   if not os.path.exists(flags_path):
     shutil.copy2(os.path.join(model_dir, "flags.txt"), flags_path)

+  locals_and_flags = {'FLAGS': FLAGS}
   for model in bleu_hook.stepfiles_iterator(model_dir, FLAGS.wait_minutes,
                                             FLAGS.min_steps):
     tf.logging.info("Translating " + model.filename)
     out_file = translated_base_file + "-" + str(model.steps)
+    locals_and_flags.update(locals())
     if os.path.exists(out_file):
       tf.logging.info(out_file + " already exists, so skipping it.")
     else:
@@ -96,7 +98,7 @@ def main(_):
          "--model={FLAGS.model} --hparams_set={FLAGS.hparams_set} "
          "--checkpoint_path={model.filename} --decode_from_file={source} "
          "--decode_to_file={out_file}"
-      ).format(**locals())
+      ).format(**locals_and_flags)
       command = FLAGS.decoder_command.format(**locals())
       tf.logging.info("Running:\n" + command)
       os.system(command)
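A minimal, self-contained sketch of why locals_and_flags is built (the _Flags class and its values below are stand-ins, not the repo's code): str.format(**locals()) only sees names bound inside the current function, so the module-level FLAGS object has to be merged into the mapping explicitly before fields like {FLAGS.model} can resolve.

class _Flags(object):
  # Stand-in for the real absl/tf FLAGS object, for illustration only.
  model = "transformer"
  hparams_set = "transformer_base"

FLAGS = _Flags()

def build_params(model_filename, source, out_file):
  # FLAGS is a module-level name, so it is absent from locals() here;
  # merge it in explicitly, as the diff above does.
  locals_and_flags = {"FLAGS": FLAGS}
  locals_and_flags.update(locals())
  return ("--model={FLAGS.model} --hparams_set={FLAGS.hparams_set} "
          "--checkpoint_path={model_filename} --decode_from_file={source} "
          "--decode_to_file={out_file}").format(**locals_and_flags)

print(build_params("model.ckpt-1000", "newstest.src", "out-1000"))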

tensor2tensor/utils/bleu_hook.py
Lines changed: 16 additions & 2 deletions

@@ -24,6 +24,7 @@
 import re
 import sys
 import time
+import glob
 import unicodedata

 # Dependency imports
@@ -158,6 +159,7 @@ def property_chars(self, prefix):
     return "".join(six.unichr(x) for x in range(sys.maxunicode)
                    if unicodedata.category(six.unichr(x)).startswith(prefix))

+uregex = UnicodeRegex()

 def bleu_tokenize(string):
   r"""Tokenize a string following the official BLEU implementation.
@@ -183,7 +185,6 @@ def bleu_tokenize(string):
   Returns:
     a list of tokens
   """
-  uregex = UnicodeRegex()
   string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
   string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
   string = uregex.symbol_re.sub(r" \1 ", string)
@@ -205,11 +206,24 @@ def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):

 StepFile = collections.namedtuple("StepFile", "filename mtime ctime steps")

+def _try_twice_tf_glob(pattern):
+  """tf.gfile.Glob may crash with
+  tensorflow.python.framework.errors_impl.NotFoundError:
+  xy/model.ckpt-1130761_temp_9cb4cb0b0f5f4382b5ea947aadfb7a40;
+  No such file or directory
+
+  Standard glob.glob does not have this bug, but does not handle gs://...
+  So let's use tf.gfile.Glob twice to handle most concurrency problems.
+  """
+  try:
+    return tf.gfile.Glob(pattern)
+  except tf.errors.NotFoundError:
+    return tf.gfile.Glob(pattern)

 def _read_stepfiles_list(path_prefix, path_suffix=".index", min_steps=0):
   """Return list of StepFiles sorted by step from files at path_prefix."""
   stepfiles = []
-  for filename in tf.gfile.Glob(path_prefix + "*-[0-9]*" + path_suffix):
+  for filename in _try_twice_tf_glob(path_prefix + '*-[0-9]*' + path_suffix):
     basename = filename[:-len(path_suffix)] if len(path_suffix) else filename
     try:
       steps = int(basename.rsplit("-")[-1])
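Two independent fixes land in this file: the module-level uregex singleton and the retried glob. As a rough standalone sketch of the tokenization speedup only (the toy helper below is not the repo's UnicodeRegex): building the character classes scans every Unicode code point, so constructing them once at import time, rather than on every bleu_tokenize call, removes the dominant cost.

import re
import sys
import unicodedata

def _property_chars(prefix):
  # Expensive: walks all ~1.1 million Unicode code points, analogous to the
  # UnicodeRegex.property_chars method shown in the diff above.
  return "".join(chr(cp) for cp in range(sys.maxunicode)
                 if unicodedata.category(chr(cp)).startswith(prefix))

# Built once at import time; doing this inside the tokenizer on every call is
# what the commit removes.
_PUNCT_RE = re.compile("([" + re.escape(_property_chars("P")) + "])")

def toy_tokenize(text):
  # Reuses the module-level singleton instead of rebuilding it per call.
  return _PUNCT_RE.sub(r" \1 ", text).split()

print(toy_tokenize("Hello, world!"))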

tensor2tensor/utils/decoding.py
Lines changed: 3 additions & 2 deletions

@@ -219,7 +219,8 @@ def decode_from_file(estimator,
                      filename,
                      hparams,
                      decode_hp,
-                     decode_to_file=None):
+                     decode_to_file=None,
+                     checkpoint_path=None):
   """Compute predictions on entries in filename and write them out."""
   if not decode_hp.batch_size:
     decode_hp.batch_size = 32
@@ -248,7 +249,7 @@ def input_fn():
     return _decode_input_tensor_to_features_dict(example, hparams)

   decodes = []
-  result_iter = estimator.predict(input_fn)
+  result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
   for result in result_iter:
     if decode_hp.return_beams:
       beam_decodes = []
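With checkpoint_path threaded through to estimator.predict, decoding can target a specific checkpoint instead of whatever is newest in output_dir. An illustrative sketch only, assuming the TF 1.x Estimator convention that a None checkpoint_path falls back to the latest checkpoint in model_dir (the helper name is made up):

import tensorflow as tf

def resolve_checkpoint(model_dir, checkpoint_path=None):
  # Mimics the selection estimator.predict performs: an explicit path wins,
  # otherwise fall back to the newest checkpoint in model_dir.
  if checkpoint_path:
    return checkpoint_path
  latest = tf.train.latest_checkpoint(model_dir)
  if latest is None:
    raise ValueError("No checkpoint found in %s" % model_dir)
  return latest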

0 commit comments
