
Commit b99fac2

Merge pull request #678 from martinpopel/transforer-relative-decode-fix
Fix transformer decoding when using attention other than dot_product.
2 parents 957b5cd + 64c7f6a commit b99fac2

1 file changed, 7 insertions(+), 0 deletions(-)

tensor2tensor/models/transformer.py

Lines changed: 7 additions & 0 deletions
@@ -225,6 +225,13 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
              None if using greedy decoding (beam_size=1)
       }
     """
+    if self._hparams.self_attention_type != "dot_product":
+      # Caching is not guaranteed to work with attention types other than
+      # dot_product.
+      # TODO(petershaw): Support fast decoding when using relative
+      # position representations, i.e. "dot_product_relative" attention.
+      return self._beam_decode_slow(features, decode_length, beam_size,
+                                    top_beams, alpha)
     with tf.variable_scope(self.name):
       return self._fast_decode(
           features, decode_length, beam_size, top_beams, alpha)
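
For readers trying out the fix, here is a minimal sketch (not part of the commit) of the configuration that triggers the new fallback. It assumes tensor2tensor's registered transformer_base hparams set and its self_attention_type hparam; "dot_product_relative" is the relative-position attention type named in the TODO above.

# A minimal sketch, assuming the transformer_base hparams set; not part of
# this commit.
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()
# Any value other than "dot_product" now makes _beam_decode return
# self._beam_decode_slow(...) instead of entering the cached _fast_decode
# path under tf.variable_scope(self.name).
hparams.self_attention_type = "dot_product_relative"

With this hparam set, decoding still works but gives up the cached fast-decode path, which is exactly the trade-off this commit makes for attention types other than dot_product.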
