Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 6f1152c

Browse files
author
Ryan Sepassi
committed
TransformerScorer model to only score targets on infer
PiperOrigin-RevId: 192812089
1 parent a7c150e commit 6f1152c

File tree

2 files changed

+100
-2
lines changed

2 files changed

+100
-2
lines changed

tensor2tensor/models/transformer.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,59 @@ def is_not_finished(i, finished, *_):
535535
return {"outputs": decoded_ids, "scores": scores}
536536

537537

538+
@registry.register_model
539+
class TransformerScorer(Transformer):
540+
"""Transformer model, but only scores in PREDICT mode.
541+
542+
Checkpoints between Transformer and TransformerScorer are interchangeable.
543+
"""
544+
545+
def __init__(self, *args, **kwargs):
546+
super(TransformerScorer, self).__init__(*args, **kwargs)
547+
self._name = "transformer"
548+
self._base_name = "transformer"
549+
550+
def infer(self,
551+
features=None,
552+
decode_length=50,
553+
beam_size=1,
554+
top_beams=1,
555+
alpha=0.0):
556+
"""Returns the targets and their log probabilities."""
557+
del decode_length, beam_size, top_beams, alpha
558+
assert features is not None
559+
560+
# Run the model
561+
self.hparams.force_full_predict = True
562+
with tf.variable_scope(self.name):
563+
logits, _ = self.model_fn(features)
564+
assert len(logits.shape) == 5 # [batch, time, 1, 1, vocab]
565+
logits = tf.squeeze(logits, [2, 3])
566+
567+
# Compute the log probabilities
568+
log_probs = beam_search.log_prob_from_logits(logits)
569+
570+
# Slice out the log_probs of the targets
571+
targets = features["targets"]
572+
assert len(targets.shape) == 4 # [batch, time, 1, 1]
573+
targets = tf.squeeze(targets, [2, 3])
574+
batch_size, timesteps = common_layers.shape_list(targets)
575+
vocab_size = common_layers.shape_list(log_probs)[-1]
576+
flat_targets = tf.reshape(targets, [batch_size * timesteps])
577+
flat_log_probs = tf.reshape(log_probs, [batch_size * timesteps, vocab_size])
578+
flat_indices = tf.stack(
579+
[tf.range(tf.to_int64(batch_size) * tf.to_int64(timesteps)),
580+
tf.to_int64(flat_targets)], axis=1)
581+
log_probs = tf.reshape(
582+
tf.gather_nd(flat_log_probs, flat_indices),
583+
[batch_size, timesteps])
584+
585+
# Sum over time to get the log_prob of the sequence
586+
scores = tf.reduce_sum(log_probs, axis=1)
587+
588+
return {"outputs": targets, "scores": scores}
589+
590+
538591
@registry.register_model
539592
class TransformerEncoder(t2t_model.T2TModel):
540593
"""Transformer, encoder only."""

tensor2tensor/models/transformer_test.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@
3737

3838
class TransformerTest(tf.test.TestCase):
3939

40-
def getModel(self, hparams, mode=tf.estimator.ModeKeys.TRAIN, has_input=True):
40+
def getModel(self, hparams=None, mode=tf.estimator.ModeKeys.TRAIN,
41+
has_input=True, model_cls=transformer.Transformer):
42+
if hparams is None:
43+
hparams = transformer.transformer_tiny()
4144
hparams.hidden_size = 8
4245
hparams.filter_size = 32
4346
hparams.num_heads = 1
@@ -58,7 +61,7 @@ def getModel(self, hparams, mode=tf.estimator.ModeKeys.TRAIN, has_input=True):
5861
"target_space_id": tf.constant(1, dtype=tf.int32)
5962
}
6063

61-
return transformer.Transformer(hparams, mode, p_hparams), features
64+
return model_cls(hparams, mode, p_hparams), features
6265

6366
def testTransformer(self):
6467
model, features = self.getModel(transformer.transformer_small())
@@ -240,5 +243,47 @@ def testTransformerWithEncoderDecoderAttentionLoss(self):
240243
self.assertEqual(res.shape, ())
241244

242245

246+
class TransformerScorerTest(TransformerTest):
247+
248+
def testReturnsScores(self):
249+
model, features = self.getModel(
250+
mode=tf.estimator.ModeKeys.PREDICT,
251+
model_cls=transformer.TransformerScorer)
252+
infer_out = model.infer(features)
253+
self.assertTrue("outputs" in infer_out)
254+
self.assertTrue("scores" in infer_out)
255+
256+
with self.test_session() as session:
257+
session.run(tf.global_variables_initializer())
258+
infer_out = session.run(infer_out)
259+
self.assertEqual((BATCH_SIZE,), infer_out["scores"].shape)
260+
self.assertEqual((BATCH_SIZE, TARGET_LENGTH), infer_out["outputs"].shape)
261+
262+
def testVarNames(self):
263+
with tf.Graph().as_default():
264+
model, features = self.getModel(
265+
mode=tf.estimator.ModeKeys.PREDICT,
266+
model_cls=transformer.TransformerScorer)
267+
_ = model.infer(features)
268+
scorer_vars = [v.name for v in tf.global_variables()]
269+
270+
with tf.Graph().as_default():
271+
model, features = self.getModel(
272+
mode=tf.estimator.ModeKeys.EVAL,
273+
model_cls=transformer.TransformerScorer)
274+
_ = model(features)
275+
scorer_eval_vars = [v.name for v in tf.global_variables()]
276+
277+
with tf.Graph().as_default():
278+
model, features = self.getModel(
279+
mode=tf.estimator.ModeKeys.EVAL,
280+
model_cls=transformer.Transformer)
281+
_ = model(features)
282+
transformer_vars = [v.name for v in tf.global_variables()]
283+
284+
self.assertEqual(sorted(scorer_vars), sorted(transformer_vars))
285+
self.assertEqual(sorted(scorer_eval_vars), sorted(transformer_vars))
286+
287+
243288
if __name__ == "__main__":
244289
tf.test.main()

0 commit comments

Comments
 (0)