Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 957b5cd

Browse files
yynilrsepassi
authored andcommitted
Fix two bugs : 1. Add decode_list in ClassLabelEncoder, 2. Fix the bug for language model problem's beam search (#679)
* 1. add decode_list for ClassLabelEncoder 2. For language model problem(with no inputs), fix the wrong beam_size searching in Transformer model. * Make the transformer as the original one
1 parent 518eaba commit 957b5cd

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

tensor2tensor/data_generators/text_encoder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,12 @@ def encode(self, label_str):
208208

209209
def decode(self, label_id):
210210
if isinstance(label_id, list):
211-
assert len(label_id) == 1
212-
label_id, = label_id
211+
return self._class_labels[label_id[0]]
213212
return self._class_labels[label_id]
214213

214+
def decode_list(self, ids):
215+
return [self._class_labels[i] for i in ids]
216+
215217
@property
216218
def vocab_size(self):
217219
return len(self._class_labels)

tensor2tensor/models/transformer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,8 +389,10 @@ def forced_logits():
389389
top_beams=top_beams,
390390
alpha=alpha,
391391
batch_size=batch_size)
392-
if partial_targets is not None:
392+
if partial_targets is not None and beam_size == 1:
393393
ret["outputs"] = ret["outputs"][:, partial_targets_length:]
394+
elif partial_targets is not None and beam_size > 1:
395+
ret["outputs"] = ret["outputs"][:, :,partial_targets_length:]
394396
return ret
395397

396398

0 commit comments

Comments
 (0)