Skip to content

Commit 18da50a

Browse files
authored
fix d2s generation bug (#6740)
1 parent 98bfddd commit 18da50a

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

paddlenlp/transformers/generation_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,8 +1398,8 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
13981398
unfinished_flag,
13991399
model_kwargs,
14001400
)
1401-
paddle.increment(cur_len)
1402-
paddle.increment(cur_len_gpu)
1401+
paddle.increment(cur_len)
1402+
paddle.increment(cur_len_gpu)
14031403
else:
14041404
while cur_len < max_length and paddle.any(unfinished_flag):
14051405
input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
@@ -1411,8 +1411,8 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
14111411
unfinished_flag,
14121412
model_kwargs,
14131413
)
1414-
paddle.increment(cur_len)
1415-
paddle.increment(cur_len_gpu)
1414+
paddle.increment(cur_len)
1415+
paddle.increment(cur_len_gpu)
14161416

14171417
return input_ids[:, origin_len:], scores
14181418

0 commit comments

Comments
 (0)