@@ -120,6 +120,11 @@ def __init__(self,
     self._create_modalities(self._problem_hparams, self._hparams)
     if not common_layers.is_xla_compiled():
       self.summarize_hparams()
+    self._variable_scopes = {}
+
+  def _add_variable_scope(self, key, vs):
+    if key not in self._variable_scopes:
+      self._variable_scopes[key] = vs
 
   def summarize_hparams(self):
     def create_hparams_summary(hparams, name):
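The new `_variable_scopes` dict is a simple registry mapping a string key to the first `tf.VariableScope` captured under that key. A minimal standalone sketch of the pattern, outside the T2TModel class (the names below are illustrative, not part of this diff):

import tensorflow as tf

_variable_scopes = {}

def _add_variable_scope(key, vs):
  # First registration wins; later calls with the same key are no-ops.
  if key not in _variable_scopes:
    _variable_scopes[key] = vs

with tf.variable_scope("body") as body_vs:
  _add_variable_scope("body", body_vs)
  w = tf.get_variable("w", shape=[4], initializer=tf.zeros_initializer())

# The stored VariableScope object can be re-entered later to reuse variables.
with tf.variable_scope(_variable_scopes["body"], reuse=True):
  w_again = tf.get_variable("w", shape=[4])
assert w is w_again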
@@ -261,15 +266,17 @@ def model_fn_sharded(self, sharded_features):
     return sharded_logits, losses
 
   def model_fn(self, features):
-    with tf.variable_scope(tf.get_variable_scope(), use_resource=True):
+    with tf.variable_scope(tf.get_variable_scope(), use_resource=True) as vs:
+      self._add_variable_scope("model_fn", vs)
       transformed_features = self.bottom(features)
 
       if self.hparams.activation_dtype == "bfloat16":
         for k, v in sorted(six.iteritems(transformed_features)):
           if v.dtype == tf.float32:
             transformed_features[k] = tf.cast(v, tf.bfloat16)
 
-      with tf.variable_scope("body"):
+      with tf.variable_scope("body") as body_vs:
+        self._add_variable_scope("body", body_vs)
         log_info("Building model body")
         body_out = self.body(transformed_features)
       output, losses = self._normalize_body_output(body_out)
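An aside on the unchanged `use_resource=True` argument: re-entering the current scope with `use_resource=True` makes subsequent `tf.get_variable` calls create `ResourceVariable`s rather than legacy ref variables. A quick illustration (not part of the diff):

import tensorflow as tf

with tf.variable_scope(tf.get_variable_scope(), use_resource=True):
  v = tf.get_variable("v", shape=[])

print(type(v).__name__)  # ResourceVariable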
@@ -302,7 +309,8 @@ def bottom(self, features):
         tf.logging.warning("Missing feature %s - ignoring." % key)
         continue
       do_reuse = input_modality.name in all_previous_modalities
-      with tf.variable_scope(input_modality.name, reuse=do_reuse):
+      with tf.variable_scope(input_modality.name, reuse=do_reuse) as im_vs:
+        self._add_variable_scope(input_modality.name, im_vs)
         log_info("Transforming feature '%s' with %s.bottom", key,
                  input_modality.name)
         transformed_features[key] = input_modality.bottom(features[key])
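For context on the unchanged `reuse=do_reuse` flag: when several features share a modality, reusing the scope makes them share one set of bottom variables. A small self-contained illustration (names are made up):

import tensorflow as tf

def bottom(name, reuse):
  with tf.variable_scope(name, reuse=reuse):
    return tf.get_variable("emb", shape=[8, 4])

e1 = bottom("symbol_modality", reuse=False)  # creates symbol_modality/emb
e2 = bottom("symbol_modality", reuse=True)   # reuses the same variable
assert e1 is e2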
@@ -313,14 +321,16 @@ def bottom(self, features):
     if isinstance(target_modality, dict):
       for k, v in six.iteritems(target_modality):
         if k in features:
-          with tf.variable_scope(
-              "%s/%s" % (v.name, k)):  # TODO(aidangomez): share variables?
+          # TODO(aidangomez): share variables?
+          with tf.variable_scope("%s/%s" % (v.name, k)) as tm_vs:
+            self._add_variable_scope("%s/%s" % (v.name, k), tm_vs)
             log_info("Transforming '%s' with %s.targets_bottom", k, v.name)
             transformed_features[k] = v.targets_bottom(features[k])
         else:
           tf.logging.warn("Modality not found in features: %s", k)
     else:
-      with tf.variable_scope(target_modality.name):
+      with tf.variable_scope(target_modality.name) as tm_vs:
+        self._add_variable_scope(target_modality.name, tm_vs)
         if "targets" in features:
           log_info("Transforming 'targets' with %s.targets_bottom",
                    target_modality.name)
@@ -359,7 +369,8 @@ def _top_single(self, body_output, target_modality, features):
       log_warn("Without a Problem, T2TModel.top is a passthrough.")
       return body_output
 
-    with tf.variable_scope(target_modality.name):
+    with tf.variable_scope(target_modality.name) as tm_vs:
+      self._add_variable_scope(tm_vs.name, tm_vs)
       log_info("Transforming body output with %s.top", target_modality.name)
       last_only = (
          target_modality.top_is_pointwise and
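Note that, unlike the `bottom` hunks above, this registration keys on `tm_vs.name` rather than the bare modality name. A captured scope's `.name` is its fully qualified path, which also keeps this top-side entry distinct from any bottom-side entry registered under the bare name. A quick illustration (scope names are invented):

import tensorflow as tf

with tf.variable_scope("outer"):
  with tf.variable_scope("targets_modality") as tm_vs:
    pass

print(tm_vs.name)  # outer/targets_modality, not just targets_modality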
@@ -401,7 +412,9 @@ def top(self, body_output, features):
               "problem_hparams.target_modality's dict." % k)
       logits = {}
       for k, v in six.iteritems(body_output):
-        with tf.variable_scope(k):  # TODO(aidangomez): share variables here?
+        # TODO(aidangomez): share variables here?
+        with tf.variable_scope(k) as top_vs:
+          self._add_variable_scope("top_%s" % k, top_vs)
           logits[k] = self._top_single(v, target_modality[k], features)
       return logits
     else:
@@ -1270,26 +1283,33 @@ def estimator_model_fn(cls,
       return model.estimator_spec_train(
           loss, num_async_replicas=num_async_replicas, use_tpu=use_tpu)
 
+  def initialize_from_ckpt(self, ckpt_dir):
+    model_dir = self._hparams.get("model_dir", None)
+    already_has_ckpt = (
+        model_dir and tf.train.latest_checkpoint(model_dir) is not None)
+    if already_has_ckpt:
+      return
+
+    # TODO(mitchellstern): Add support for partitioned variables?
+    reader = tf.contrib.framework.load_checkpoint(ckpt_dir)
+    variable_map = {}
+    for var in tf.contrib.framework.get_trainable_variables():
+      var_name = var.name.split(":")[0]
+      if reader.has_tensor(var_name):
+        tf.logging.info("Loading variable from checkpoint: %s", var_name)
+        variable_map[var_name] = var
+      else:
+        tf.logging.info(
+            "Cannot find variable in checkpoint, skipping: %s", var_name)
+    tf.train.init_from_checkpoint(ckpt_dir, variable_map)
+
   def estimator_spec_train(self, loss, num_async_replicas=1, use_tpu=False):
     """Construct EstimatorSpec for TRAIN mode."""
     train_op = self.optimize(loss, num_async_replicas=num_async_replicas,
                              use_tpu=use_tpu)
 
-    # TODO(mitchellstern): Add support for partitioned variables?
-    if (tf.train.latest_checkpoint(self._hparams.model_dir) is None and
-        self._hparams.pretrained_model_dir):
-      pretrained_model_dir = self._hparams.pretrained_model_dir
-      reader = tf.contrib.framework.load_checkpoint(pretrained_model_dir)
-      variable_map = {}
-      for var in tf.contrib.framework.get_trainable_variables():
-        var_name = var.name.split(":")[0]
-        if reader.has_tensor(var_name):
-          tf.logging.info("Loading variable from checkpoint: %s", var_name)
-          variable_map[var_name] = var
-        else:
-          tf.logging.info(
-              "Cannot find variable in checkpoint, skipping: %s", var_name)
-      tf.train.init_from_checkpoint(pretrained_model_dir, variable_map)
+    if self._hparams.warm_start_from:
+      self.initialize_from_ckpt(self._hparams.warm_start_from)
 
     if use_tpu:
       host_call = _create_host_call(self.hparams.model_dir)
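This refactor moves the warm-start logic into a reusable `initialize_from_ckpt` method and switches the trigger from `pretrained_model_dir` to the `warm_start_from` hparam. A hedged usage sketch (the paths and the HParams construction are illustrative):

import tensorflow as tf

hparams = tf.contrib.training.HParams(
    model_dir="/tmp/t2t_train",         # training output directory
    warm_start_from="/tmp/pretrained")  # checkpoint dir to warm-start from

# During estimator_spec_train the model now runs, in effect:
#   if hparams.warm_start_from:
#     model.initialize_from_ckpt(hparams.warm_start_from)
# initialize_from_ckpt is a no-op once model_dir already contains a
# checkpoint, so resumed runs are not overwritten by pretrained weights.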