
Commit 8277f50

Ryan Sepassi authored and Copybara-Service committed
Store variable scopes in T2TModel; add T2TModel.initialize_from_ckpt
PiperOrigin-RevId: 209218783
1 parent 837990f commit 8277f50
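
The core idea of "store variable scopes" is that every tf.variable_scope the model enters while building its graph is recorded under a string key, first writer wins, so the scopes can be looked up later. As a rough illustration only (ToyModel and its names are hypothetical, not code from this commit; assumes the TF 1.x API tensor2tensor used at the time):

import tensorflow as tf


class ToyModel(object):
  """Records each variable scope it builds, keyed by name."""

  def __init__(self):
    self._variable_scopes = {}

  def _add_variable_scope(self, key, vs):
    # Keep only the first scope registered under a given key.
    if key not in self._variable_scopes:
      self._variable_scopes[key] = vs

  def build(self):
    with tf.variable_scope("body") as body_vs:
      self._add_variable_scope("body", body_vs)
      tf.get_variable("w", shape=[2, 2])


model = ToyModel()
model.build()
print(model._variable_scopes["body"].name)  # -> "body"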

File tree

3 files changed, +47 -28 lines changed


tensor2tensor/bin/t2t_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -180,6 +180,7 @@ def create_experiment_fn(**kwargs):
       use_tpu=FLAGS.use_tpu,
       use_tpu_estimator=FLAGS.use_tpu_estimator,
       use_xla=FLAGS.xla_compile,
+      warm_start_from=FLAGS.warm_start_from,
       **kwargs)


@@ -214,7 +215,6 @@ def create_run_config(hp):
       hp.weight_dtype == "float32")
   return trainer_lib.create_run_config(
       model_dir=os.path.expanduser(FLAGS.output_dir),
-      warm_start_from=FLAGS.warm_start_from,
       master=FLAGS.master,
       iterations_per_loop=FLAGS.iterations_per_loop,
       num_shards=FLAGS.tpu_num_shards,

tensor2tensor/utils/t2t_model.py

Lines changed: 43 additions & 23 deletions
@@ -120,6 +120,11 @@ def __init__(self,
     self._create_modalities(self._problem_hparams, self._hparams)
     if not common_layers.is_xla_compiled():
       self.summarize_hparams()
+    self._variable_scopes = {}
+
+  def _add_variable_scope(self, key, vs):
+    if key not in self._variable_scopes:
+      self._variable_scopes[key] = vs

   def summarize_hparams(self):
     def create_hparams_summary(hparams, name):

@@ -261,15 +266,17 @@ def model_fn_sharded(self, sharded_features):
     return sharded_logits, losses

   def model_fn(self, features):
-    with tf.variable_scope(tf.get_variable_scope(), use_resource=True):
+    with tf.variable_scope(tf.get_variable_scope(), use_resource=True) as vs:
+      self._add_variable_scope("model_fn", vs)
       transformed_features = self.bottom(features)

       if self.hparams.activation_dtype == "bfloat16":
         for k, v in sorted(six.iteritems(transformed_features)):
           if v.dtype == tf.float32:
             transformed_features[k] = tf.cast(v, tf.bfloat16)

-      with tf.variable_scope("body"):
+      with tf.variable_scope("body") as body_vs:
+        self._add_variable_scope("body", body_vs)
         log_info("Building model body")
         body_out = self.body(transformed_features)
       output, losses = self._normalize_body_output(body_out)

@@ -302,7 +309,8 @@ def bottom(self, features):
         tf.logging.warning("Missing feature %s - ignoring." % key)
         continue
       do_reuse = input_modality.name in all_previous_modalities
-      with tf.variable_scope(input_modality.name, reuse=do_reuse):
+      with tf.variable_scope(input_modality.name, reuse=do_reuse) as im_vs:
+        self._add_variable_scope(input_modality.name, im_vs)
         log_info("Transforming feature '%s' with %s.bottom", key,
                  input_modality.name)
         transformed_features[key] = input_modality.bottom(features[key])

@@ -313,14 +321,16 @@ def bottom(self, features):
     if isinstance(target_modality, dict):
       for k, v in six.iteritems(target_modality):
         if k in features:
-          with tf.variable_scope(
-              "%s/%s" % (v.name, k)):  # TODO(aidangomez): share variables?
+          # TODO(aidangomez): share variables?
+          with tf.variable_scope("%s/%s" % (v.name, k)) as tm_vs:
+            self._add_variable_scope("%s/%s" % (v.name, k), tm_vs)
             log_info("Transforming '%s' with %s.targets_bottom", k, v.name)
             transformed_features[k] = v.targets_bottom(features[k])
         else:
           tf.logging.warn("Modality not found in features: %s", k)
     else:
-      with tf.variable_scope(target_modality.name):
+      with tf.variable_scope(target_modality.name) as tm_vs:
+        self._add_variable_scope(target_modality.name, tm_vs)
         if "targets" in features:
           log_info("Transforming 'targets' with %s.targets_bottom",
                    target_modality.name)

@@ -359,7 +369,8 @@ def _top_single(self, body_output, target_modality, features):
       log_warn("Without a Problem, T2TModel.top is a passthrough.")
       return body_output

-    with tf.variable_scope(target_modality.name):
+    with tf.variable_scope(target_modality.name) as tm_vs:
+      self._add_variable_scope(tm_vs.name, tm_vs)
       log_info("Transforming body output with %s.top", target_modality.name)
       last_only = (
           target_modality.top_is_pointwise and

@@ -401,7 +412,9 @@ def top(self, body_output, features):
               "problem_hparams.target_modality's dict." % k)
       logits = {}
       for k, v in six.iteritems(body_output):
-        with tf.variable_scope(k):  # TODO(aidangomez): share variables here?
+        # TODO(aidangomez): share variables here?
+        with tf.variable_scope(k) as top_vs:
+          self._add_variable_scope("top_%s" % k, top_vs)
           logits[k] = self._top_single(v, target_modality[k], features)
       return logits
     else:

@@ -1270,26 +1283,33 @@ def estimator_model_fn(cls,
       return model.estimator_spec_train(
           loss, num_async_replicas=num_async_replicas, use_tpu=use_tpu)

+  def initialize_from_ckpt(self, ckpt_dir):
+    model_dir = self._hparams.get("model_dir", None)
+    already_has_ckpt = (
+        model_dir and tf.train.latest_checkpoint(model_dir) is not None)
+    if already_has_ckpt:
+      return
+
+    # TODO(mitchellstern): Add support for partitioned variables?
+    reader = tf.contrib.framework.load_checkpoint(ckpt_dir)
+    variable_map = {}
+    for var in tf.contrib.framework.get_trainable_variables():
+      var_name = var.name.split(":")[0]
+      if reader.has_tensor(var_name):
+        tf.logging.info("Loading variable from checkpoint: %s", var_name)
+        variable_map[var_name] = var
+      else:
+        tf.logging.info(
+            "Cannot find variable in checkpoint, skipping: %s", var_name)
+    tf.train.init_from_checkpoint(ckpt_dir, variable_map)
+
   def estimator_spec_train(self, loss, num_async_replicas=1, use_tpu=False):
     """Construct EstimatorSpec for TRAIN mode."""
     train_op = self.optimize(loss, num_async_replicas=num_async_replicas,
                              use_tpu=use_tpu)

-    # TODO(mitchellstern): Add support for partitioned variables?
-    if (tf.train.latest_checkpoint(self._hparams.model_dir) is None and
-        self._hparams.pretrained_model_dir):
-      pretrained_model_dir = self._hparams.pretrained_model_dir
-      reader = tf.contrib.framework.load_checkpoint(pretrained_model_dir)
-      variable_map = {}
-      for var in tf.contrib.framework.get_trainable_variables():
-        var_name = var.name.split(":")[0]
-        if reader.has_tensor(var_name):
-          tf.logging.info("Loading variable from checkpoint: %s", var_name)
-          variable_map[var_name] = var
-        else:
-          tf.logging.info(
-              "Cannot find variable in checkpoint, skipping: %s", var_name)
-      tf.train.init_from_checkpoint(pretrained_model_dir, variable_map)
+    if self._hparams.warm_start_from:
+      self.initialize_from_ckpt(self._hparams.warm_start_from)

     if use_tpu:
       host_call = _create_host_call(self.hparams.model_dir)
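
For readers skimming the hunk above, a condensed sketch of the warm-start pattern that initialize_from_ckpt relies on; the helper name warm_start_trainable_vars is a placeholder, not tensor2tensor API. Only trainable variables whose names also appear in the checkpoint are mapped, and tf.train.init_from_checkpoint then rewrites their initializers to read from that checkpoint, so it must run before variables are initialized (assumes TF 1.x with tf.contrib available):

import tensorflow as tf


def warm_start_trainable_vars(ckpt_dir):
  """Map same-named trainable variables onto tensors stored in ckpt_dir."""
  reader = tf.contrib.framework.load_checkpoint(ckpt_dir)
  variable_map = {}
  for var in tf.contrib.framework.get_trainable_variables():
    var_name = var.name.split(":")[0]  # drop the ":0" output suffix
    if reader.has_tensor(var_name):
      variable_map[var_name] = var
  # Overrides the mapped variables' initializers with checkpoint values.
  tf.train.init_from_checkpoint(ckpt_dir, variable_map)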

tensor2tensor/utils/trainer_lib.py

Lines changed: 3 additions & 4 deletions
@@ -105,7 +105,6 @@ def is_cloud_async_distributed():

 def create_run_config(master="",
                       model_dir=None,
-                      warm_start_from=None,
                       iterations_per_loop=1000,
                       num_shards=8,
                       log_device_placement=False,

@@ -197,7 +196,6 @@ def create_run_config(master="",
     del run_config_args["evaluation_master"]

   config = run_config_cls(**run_config_args)
-  config.warm_start_from = warm_start_from

   # If not using TPU, add device info for data_parallelism
   config.use_tpu = use_tpu

@@ -259,7 +257,6 @@ def create_estimator(model_name,
       model_fn=model_fn,
       model_dir=run_config.model_dir,
       config=run_config,
-      warm_start_from=run_config.warm_start_from
   )
   return estimator

@@ -432,14 +429,16 @@ def create_experiment(
     use_tpu_estimator=False,
     use_xla=False,
     additional_train_hooks=None,
-    additional_eval_hooks=None):
+    additional_eval_hooks=None,
+    warm_start_from=None):
   """Create Experiment."""
   # HParams
   hparams.add_hparam("model_dir", run_config.model_dir)
   hparams.add_hparam("data_dir", data_dir)
   hparams.add_hparam("train_steps", train_steps)
   hparams.add_hparam("eval_steps", eval_steps)
   hparams.add_hparam("schedule", schedule)
+  hparams.add_hparam("warm_start_from", warm_start_from)
   add_problem_hparams(hparams, problem_name)

   # Estimator
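
Taken together with the t2t_model.py change, warm_start_from now travels as an hparam instead of riding on the RunConfig/Estimator. A rough, illustrative sketch of that flow (the checkpoint path is a placeholder and the HParams object stands in for the real hparams built by create_experiment; assumes TF 1.x with tf.contrib):

import tensorflow as tf

hparams = tf.contrib.training.HParams(model="transformer")
# create_experiment records the flag on hparams...
hparams.add_hparam("warm_start_from", "/path/to/pretrained_ckpt")

# ...and T2TModel.estimator_spec_train later does the equivalent of:
if hparams.warm_start_from:
  print("would call model.initialize_from_ckpt(%s)" % hparams.warm_start_from)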
