@@ -460,6 +460,11 @@ def top(self, body_output, features):
460
460
Returns:
461
461
logits: dict of str to Tensor, denoting each logits for each target; or
462
462
a single Tensor denoting the logits for that target.
463
+ When targets are generated at training time:
464
+ logits == {
465
+ "self_generated_targets": <generated targets tensor>
466
+ "logits": <original logits Tensor or dict>
467
+ }
463
468
"""
464
469
if isinstance (body_output , dict ):
465
470
if self ._problem_hparams :
@@ -1306,6 +1311,22 @@ def estimator_model_fn(cls,
1306
1311
else :
1307
1312
logits , losses_dict = model (features ) # pylint: disable=not-callable
1308
1313
1314
+ # Support model-generated labels by overriding features["targets"] with
1315
+ # logits["self_generated_targets"].
1316
+ if isinstance (logits , dict ) and "self_generated_targets" in logits :
1317
+ # Overwrite 'features["targets"]' and 'labels'
1318
+ # by logits["self_generated_targets"].
1319
+ tf .logging .info ("Replacing targets with model-provided targets." )
1320
+ features ["targets" ] = labels = logits .pop ("self_generated_targets" )
1321
+ assert logits .keys () == ["logits" ], (
1322
+ # See "Returns" in the "top" method docstring for the expected
1323
+ # "logits" format when targets are generated at training time.
1324
+ "Expect only key 'logits' when there is 'self_generated_targets'. "
1325
+ "Found {}" .format (logits .keys ())
1326
+ )
1327
+ # Recover the original logits tensor from the logits dict.
1328
+ logits = logits ["logits" ] # Can be a tf.Tensor or a dict.
1329
+
1309
1330
# Set known shapes
1310
1331
if common_layers .is_xla_compiled ():
1311
1332
if isinstance (logits , dict ):
0 commit comments