Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit e8ae589

Browse files
T2T Team authored and Ryan Sepassi committed
Support for dictionary losses in model_fn_body to be consistent with model_fn_body_sharded. Also updated inline doc.
PiperOrigin-RevId: 164305140
1 parent 7efdbee commit e8ae589

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tensor2tensor/utils/t2t_model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,10 @@ def model_fn_body_sharded(self, sharded_features):
469469
_with_timing(self.model_fn_body, "model_fn_body"),
470470
datashard_to_features)
471471
if isinstance(output, tuple):
472-
loss = {"extra": tf.reduce_mean(output[1])}
472+
if isinstance(output[1], dict):
473+
loss = output[1]
474+
else:
475+
loss = {"extra": tf.reduce_mean(output[1])}
473476
output = output[0]
474477
else:
475478
loss = {"extra": 0.0}
@@ -483,10 +486,12 @@ def model_fn_body(self, features):
483486
484487
Args:
485488
features: A dictionary of key to Tensor. Each Tensor has shape
486-
`[batch_size, ?, ?, hidden_size]`.
489+
[batch_size, ?, ?, hidden_size].
487490
488491
Returns:
489-
a `Tensor` of logits with shape `[batch_size, O, P, body_output_size]`.
492+
output: tensor of logits with shape [batch_size, O, P, body_output_size].
493+
losses: either single loss as a scalar, a list, a tensor (to be averaged)
494+
or a dictionary of losses.
490495
"""
491496
raise NotImplementedError("Abstract Method")
492497

0 commit comments

Comments
 (0)