Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit bb08c0f

Browse files
T2T TeamCopybara-Service
authored andcommitted
Allow T2T model to generate labels in-graph.
PiperOrigin-RevId: 219156293
1 parent f1e161c commit bb08c0f

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

tensor2tensor/utils/t2t_model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,11 @@ def top(self, body_output, features):
460460
Returns:
461461
logits: dict of str to Tensor, denoting each logits for each target; or
462462
a single Tensor denoting the logits for that target.
463+
When targets are generated at training time:
464+
logits == {
465+
"self_generated_targets": <generated targets tensor>
466+
"logits": <original logits Tensor or dict>
467+
}
463468
"""
464469
if isinstance(body_output, dict):
465470
if self._problem_hparams:
@@ -1306,6 +1311,22 @@ def estimator_model_fn(cls,
13061311
else:
13071312
logits, losses_dict = model(features) # pylint: disable=not-callable
13081313

1314+
# Support model-generated labels by overriding features["targets"] with
1315+
# logits["self_generated_targets"].
1316+
if isinstance(logits, dict) and "self_generated_targets" in logits:
1317+
# Overwrite 'features["targets"]' and 'labels'
1318+
# by logits["self_generated_targets"].
1319+
tf.logging.info("Replacing targets with model-provided targets.")
1320+
features["targets"] = labels = logits.pop("self_generated_targets")
1321+
assert logits.keys() == ["logits"], (
1322+
# See "Returns" in the "top" method docstring for the expected
1323+
# "logits" format when targets are generated at training time.
1324+
"Expect only key 'logits' when there is 'self_generated_targets'. "
1325+
"Found {}".format(logits.keys())
1326+
)
1327+
# Recover the original logits tensor from the logits dict.
1328+
logits = logits["logits"] # Can be a tf.Tensor or a dict.
1329+
13091330
# Set known shapes
13101331
if common_layers.is_xla_compiled():
13111332
if isinstance(logits, dict):

0 commit comments

Comments
 (0)