This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit a2f1ee9

Author: Ryan Sepassi
Add features (export, SessionConfig, Parallelism, hooks) to TPU codepath
PiperOrigin-RevId: 179602110
1 parent 2be0cbb commit a2f1ee9

27 files changed (+488, -225 lines)

setup.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
         'tensor2tensor/bin/t2t-datagen',
         'tensor2tensor/bin/t2t-decoder',
         'tensor2tensor/bin/t2t-make-tf-configs',
+        'tensor2tensor/bin/t2t-tpu-trainer',
     ],
     install_requires=[
         'bz2file',

tensor2tensor/bin/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+

tensor2tensor/bin/t2t-decoder

Lines changed: 6 additions & 5 deletions
@@ -58,10 +58,11 @@ flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
 
 
 def create_hparams():
-  hparams = tpu_trainer.create_hparams()
-  hparams.add_hparam("data_dir", os.path.expanduser(FLAGS.data_dir))
-  tpu_trainer_lib.add_problem_hparams(hparams, FLAGS.problems)
-  return hparams
+  return tpu_trainer_lib.create_hparams(
+      FLAGS.hparams_set,
+      FLAGS.hparams,
+      data_dir=os.path.expanduser(FLAGS.data_dir),
+      problem_name=FLAGS.problems)
 
 
 def create_decode_hparams():
@@ -90,7 +91,7 @@ def decode(estimator, hparams, decode_hp):
 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
   usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
-  FLAGS.use_tpu = False
+  FLAGS.use_tpu = False  # decoding not supported on TPU
 
   hp = create_hparams()
   decode_hp = create_decode_hparams()
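
Both decoder entry points now build their hparams through tpu_trainer_lib.create_hparams which, judging from the code it replaces, also attaches data_dir and the problem's own hparams. A minimal sketch of calling it directly; the hparams set, override string, and paths below are illustrative assumptions, not part of this change:

    from tensor2tensor.tpu import tpu_trainer_lib

    hparams = tpu_trainer_lib.create_hparams(
        "transformer_base",                    # any registered hparams_set
        "batch_size=2048,max_length=128",      # comma-separated overrides; may be ""
        data_dir="/tmp/t2t_data",              # stored on hparams as data_dir
        problem_name="translate_ende_wmt8k")   # also adds the problem's hparams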

tensor2tensor/bin/t2t-tpu-trainer

Lines changed: 45 additions & 11 deletions
@@ -20,12 +20,14 @@ from __future__ import division
 from __future__ import print_function
 
 import os
+import sys
 
 # Dependency imports
 
 from tensor2tensor import models  # pylint: disable=unused-import
 from tensor2tensor import problems as problems_lib  # pylint: disable=unused-import
-from tensor2tensor.tpu import tpu_trainer_lib as lib
+from tensor2tensor.tpu import tpu_trainer_lib
+from tensor2tensor.utils import decoding
 from tensor2tensor.utils import flags as t2t_flags  # pylint: disable=unused-import
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import usr_dir
@@ -45,7 +47,7 @@ flags.DEFINE_string("t2t_usr_dir", "",
 flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
 flags.DEFINE_integer("iterations_per_loop", 1000,
                      "Number of iterations in a TPU training loop.")
-flags.DEFINE_bool("use_tpu", True, "Whether to use TPU.")
+flags.DEFINE_bool("use_tpu", False, "Whether to use TPU.")
 
 # To maintain compatibility with some internal libs, we guard against these flag
 # definitions possibly erroring. Apologies for the ugliness.
@@ -66,38 +68,66 @@ def get_problem_name():
 
 
 def create_hparams():
-  hparams = registry.hparams(FLAGS.hparams_set)()
-  if FLAGS.hparams:
-    hparams = hparams.parse(FLAGS.hparams)
-  return hparams
+  return tpu_trainer_lib.create_hparams(FLAGS.hparams_set, FLAGS.hparams)
 
 
 def create_experiment_fn():
-  return lib.create_experiment_fn(
+  use_validation_monitor = (FLAGS.schedule in
+                            ["train_and_evaluate", "continuous_train_and_eval"]
+                            and FLAGS.local_eval_frequency)
+  return tpu_trainer_lib.create_experiment_fn(
       FLAGS.model,
       get_problem_name(),
       os.path.expanduser(FLAGS.data_dir),
       FLAGS.train_steps,
       FLAGS.eval_steps,
       FLAGS.local_eval_frequency,
       FLAGS.schedule,
+      export=FLAGS.export_saved_model,
+      decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
+      use_tfdbg=FLAGS.tfdbg,
+      use_dbgprofile=FLAGS.dbgprofile,
+      use_validation_monitor=use_validation_monitor,
+      eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
+      eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
+      eval_early_stopping_metric_minimize=FLAGS.
+      eval_early_stopping_metric_minimize,
       use_tpu=FLAGS.use_tpu)
 
 
-def create_run_config():
-  return lib.create_run_config(
+def create_run_config(hp):
+  return tpu_trainer_lib.create_run_config(
       model_dir=os.path.expanduser(FLAGS.output_dir),
       master=FLAGS.master,
       iterations_per_loop=FLAGS.iterations_per_loop,
       num_shards=FLAGS.tpu_num_shards,
       log_device_placement=FLAGS.log_device_placement,
       save_checkpoints_steps=max(FLAGS.iterations_per_loop,
                                  FLAGS.local_eval_frequency),
+      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
+      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
       num_gpus=FLAGS.worker_gpu,
       gpu_order=FLAGS.gpu_order,
       shard_to_cpu=FLAGS.locally_shard_to_cpu,
       num_async_replicas=FLAGS.worker_replicas,
-      use_tpu=FLAGS.use_tpu)
+      gpu_mem_fraction=FLAGS.worker_gpu_memory_fraction,
+      enable_graph_rewriter=FLAGS.experimental_optimize_placement,
+      use_tpu=FLAGS.use_tpu,
+      schedule=FLAGS.schedule,
+      no_data_parallelism=hp.no_data_parallelism,
+      daisy_chain_variables=hp.daisy_chain_variables,
+      ps_replicas=FLAGS.ps_replicas,
+      ps_job=FLAGS.ps_job,
+      ps_gpu=FLAGS.ps_gpu,
+      sync=FLAGS.sync,
+      worker_id=FLAGS.worker_id,
+      worker_job=FLAGS.worker_job)
+
+
+def log_registry():
+  if FLAGS.registry_help:
+    tf.logging.info(registry.help_string())
+    sys.exit(0)
 
 
 def execute_schedule(exp):
@@ -111,9 +141,13 @@ def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
   tf.set_random_seed(123)
   usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
+  log_registry()
+
+  hparams = create_hparams()
+  run_config = create_run_config(hparams)
 
   exp_fn = create_experiment_fn()
-  exp = exp_fn(create_run_config(), create_hparams())
+  exp = exp_fn(run_config, hparams)
   execute_schedule(exp)
 
 
tensor2tensor/bin/t2t-trainer

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ To train your model, for example:
       --model=transformer
       --hparams_set=transformer_base
 """
+# DEPRECATED
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

tensor2tensor/bin/t2t_decoder.py

Lines changed: 6 additions & 5 deletions
@@ -57,10 +57,11 @@
 
 
 def create_hparams():
-  hparams = tpu_trainer.create_hparams()
-  hparams.add_hparam("data_dir", os.path.expanduser(FLAGS.data_dir))
-  tpu_trainer_lib.add_problem_hparams(hparams, FLAGS.problems)
-  return hparams
+  return tpu_trainer_lib.create_hparams(
+      FLAGS.hparams_set,
+      FLAGS.hparams,
+      data_dir=os.path.expanduser(FLAGS.data_dir),
+      problem_name=FLAGS.problems)
 
 
 def create_decode_hparams():
@@ -89,7 +90,7 @@ def decode(estimator, hparams, decode_hp):
 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
   usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
-  FLAGS.use_tpu = False
+  FLAGS.use_tpu = False  # decoding not supported on TPU
 
   hp = create_hparams()
   decode_hp = create_decode_hparams()

tensor2tensor/bin/t2t_trainer.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
       --model=transformer
       --hparams_set=transformer_base
 """
+# DEPRECATED
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

tensor2tensor/data_generators/problem.py

Lines changed: 68 additions & 44 deletions
@@ -383,13 +383,6 @@ def dataset(self,
     # Construct the Problem's hparams so that items within it are accessible
     _ = self.get_hparams(hparams)
 
-    data_fields, data_items_to_decoders = self.example_reading_spec()
-    if data_items_to_decoders is None:
-      data_items_to_decoders = {
-          field: tf.contrib.slim.tfexample_decoder.Tensor(field)
-          for field in data_fields
-      }
-
     is_training = mode == tf.estimator.ModeKeys.TRAIN
     data_filepattern = self.filepattern(data_dir, dataset_split, shard=shard)
     tf.logging.info("Reading data files from %s", data_filepattern)
@@ -406,22 +399,13 @@
     else:
       dataset = tf.data.TFRecordDataset(data_files)
 
-    def decode_record(record):
-      """Serialized Example to dict of <feature name, Tensor>."""
-      decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(
-          data_fields, data_items_to_decoders)
-
-      decode_items = list(data_items_to_decoders)
-      decoded = decoder.decode(record, items=decode_items)
-      return dict(zip(decode_items, decoded))
-
     def _preprocess(example):
       example = self.preprocess_example(example, mode, hparams)
       self.maybe_reverse_features(example)
       self.maybe_copy_features(example)
       return example
 
-    dataset = dataset.map(decode_record, num_parallel_calls=num_threads)
+    dataset = dataset.map(self.decode_example, num_parallel_calls=num_threads)
 
     if preprocess:
       dataset = dataset.map(_preprocess, num_parallel_calls=num_threads)
@@ -430,6 +414,22 @@ def _preprocess(example):
 
     return dataset
 
+  def decode_example(self, serialized_example):
+    """Return a dict of Tensors from a serialized tensorflow.Example."""
+    data_fields, data_items_to_decoders = self.example_reading_spec()
+    if data_items_to_decoders is None:
+      data_items_to_decoders = {
+          field: tf.contrib.slim.tfexample_decoder.Tensor(field)
+          for field in data_fields
+      }
+
+    decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(
+        data_fields, data_items_to_decoders)
+
+    decode_items = list(data_items_to_decoders)
+    decoded = decoder.decode(serialized_example, items=decode_items)
+    return dict(zip(decode_items, decoded))
+
   @property
   def has_inputs(self):
     return "inputs" in self.get_feature_encoders()
@@ -496,7 +496,8 @@ def input_fn(self, mode, hparams, params=None, config=None,
       mode: tf.estimator.ModeKeys
       hparams: HParams, model hparams
       params: dict, may include "batch_size"
-      config: RunConfig; if passed, should include t2t_device_info dict
+      config: RunConfig; should have the data_parallelism attribute if not using
+        TPU
       dataset_kwargs: dict, if passed, will pass as kwargs to self.dataset
         method when called
 
@@ -521,29 +522,8 @@ def gpu_valid_size(example):
           hparams.max_length if drop_long_sequences else 10**9)
 
     def define_shapes(example):
-      """Set the right shapes for the features."""
-      inputs = example["inputs"]
-      targets = example["targets"]
-
-      # Ensure inputs and targets are proper rank.
-      while len(inputs.get_shape()) < 4:
-        inputs = tf.expand_dims(inputs, axis=-1)
-      while len(targets.get_shape()) < 4:
-        targets = tf.expand_dims(targets, axis=-1)
-
-      example["inputs"] = inputs
-      example["targets"] = targets
-
-      if config.use_tpu:
-        # Ensure batch size is set on all features
-        for _, t in six.iteritems(example):
-          shape = t.get_shape().as_list()
-          shape[0] = params["batch_size"]
-          t.set_shape(t.get_shape().merge_with(shape))
-          # Assert shapes are fully known
-          t.get_shape().assert_is_fully_defined()
-
-      return example
+      return _standardize_shapes(
+          example, batch_size=(config.use_tpu and params["batch_size"]))
 
     # Read and preprocess
     data_dir = hparams.data_dir
@@ -569,7 +549,7 @@ def define_shapes(example):
         dataset = dataset.apply(
             tf.contrib.data.batch_and_drop_remainder(tpu_batch_size))
       else:
-        num_shards = config.t2t_device_info["num_shards"]
+        num_shards = config.data_parallelism.n
         dataset = dataset.batch(hparams.batch_size * num_shards)
     else:
       # Variable length features
@@ -586,7 +566,7 @@ def define_shapes(example):
         dataset = dataset.filter(gpu_valid_size)
       batching_scheme = data_reader.hparams_to_batching_scheme(
           hparams,
-          shard_multiplier=config.t2t_device_info["num_shards"],
+          shard_multiplier=config.data_parallelism.n,
          length_multiplier=self.get_hparams().batch_size_multiplier)
       if hparams.use_fixed_batch_size:
         batching_scheme["batch_sizes"] = [hparams.batch_size]
@@ -601,7 +581,7 @@ def define_shapes(example):
     dataset = dataset.prefetch(1)
     features = dataset.make_one_shot_iterator().get_next()
     if not config.use_tpu:
-      _summarize_features(features, config.t2t_device_info["num_shards"])
+      _summarize_features(features, config.data_parallelism.n)
 
     if mode == tf.estimator.ModeKeys.PREDICT:
       features["infer_targets"] = features["targets"]
@@ -614,6 +594,25 @@ def define_shapes(example):
 
     return features, features["targets"]
 
+  def serving_input_fn(self, hparams):
+    """Input fn for serving export, starting from serialized example."""
+    mode = tf.estimator.ModeKeys.PREDICT
+    serialized_example = tf.placeholder(
+        dtype=tf.string, shape=[None], name="serialized_example")
+    dataset = tf.data.Dataset.from_tensor_slices(serialized_example)
+    dataset = dataset.map(self.decode_example)
+    dataset = dataset.map(lambda ex: self.preprocess_example(ex, mode, hparams))
+    dataset = dataset.map(data_reader.cast_int64_to_int32)
+    dataset = dataset.padded_batch(1000, dataset.output_shapes)
+    dataset = dataset.map(_standardize_shapes)
+    features = tf.contrib.data.get_single_element(dataset)
+
+    if self.has_inputs:
+      features.pop("targets", None)
+
+    return tf.estimator.export.ServingInputReceiver(
+        features=features, receiver_tensors=serialized_example)
+
 
 class FeatureInfo(object):
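
serving_input_fn feeds serialized tf.Examples through the same decode/preprocess path used at training time, and is presumably what the new export=FLAGS.export_saved_model option in the TPU trainer builds on. A sketch of exporting a trained Estimator with it; the estimator, hparams, and export directory are assumed to exist already and are not part of this change:

    def export_saved_model(estimator, problem, hparams, export_dir):
      # export_savedmodel expects a zero-argument serving_input_receiver_fn.
      estimator.export_savedmodel(
          export_dir, lambda: problem.serving_input_fn(hparams))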

@@ -907,3 +906,28 @@ def _summarize_features(features, num_shards=1):
         tf.summary.scalar("%s_nonpadding_tokens" % k, nonpadding_tokens)
         tf.summary.scalar("%s_nonpadding_fraction" % k,
                           tf.reduce_mean(nonpadding))
+
+
+def _standardize_shapes(features, batch_size=None):
+  """Set the right shapes for the features."""
+
+  for fname in ["inputs", "targets"]:
+    if fname not in features:
+      continue
+
+    f = features[fname]
+    while len(f.get_shape()) < 4:
+      f = tf.expand_dims(f, axis=-1)
+
+    features[fname] = f
+
+  if batch_size:
+    # Ensure batch size is set on all features
+    for _, t in six.iteritems(features):
+      shape = t.get_shape().as_list()
+      shape[0] = batch_size
+      t.set_shape(t.get_shape().merge_with(shape))
+      # Assert shapes are fully known
+      t.get_shape().assert_is_fully_defined()
+
+  return features
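
_standardize_shapes is now shared by input_fn (via define_shapes) and serving_input_fn. A quick illustration of its effect, assuming it is called from within this module (so tf and six are in scope); the placeholder shapes are illustrative:

    # Rank-2 [batch, length] features are expanded to the rank-4 layout the
    # models expect; with batch_size set (the TPU path), the batch dimension
    # becomes static and the resulting shape must be fully defined.
    features = {"inputs": tf.placeholder(tf.int32, [None, 20]),
                "targets": tf.placeholder(tf.int32, [None, 20])}
    features = _standardize_shapes(features, batch_size=8)
    print(features["inputs"].get_shape())  # (8, 20, 1, 1)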

tensor2tensor/data_generators/translate_enzh.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@
 
 _ENZH_TEST_DATASETS = [[
     "http://data.statmt.org/wmt17/translation-task/dev.tgz",
-    ("dev/newsdev2017-zhen-src.en.sgm", "dev/newsdev2017-zhen-ref.zh.sgm")
+    ("dev/newsdev2017-enzh-src.en.sgm", "dev/newsdev2017-enzh-ref.zh.sgm")
 ]]
 
 
0 commit comments
