Skip to content

Commit 6fc691d

Browse files
committed
add INFERENCE_TRUNCATED_RETURN_EOS
1 parent 85df59d commit 6fc691d

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

llm/predict/predictor.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1338,6 +1338,8 @@ def predict_dy_insert(
13381338
repeat_num=1,
13391339
**kwargs
13401340
):
1341+
# NOTE(gongenlei): Output that is truncated for being too long does not return an eos_token
1342+
os.environ["INFERENCE_TRUNCATED_RETURN_EOS"] = "0"
13411343
assert repeat_num >= 1
13421344
flag_current_rank_run = self.tensor_parallel_rank == 0 or all_rank_return
13431345
self.input_ids = []

paddlenlp/rl/utils/infer_utils.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616

1717
import copy
1818
import inspect
19-
import os
2019
from contextlib import contextmanager
2120

2221
import paddle
@@ -94,8 +93,7 @@ def predict(self, input_ids: paddle.Tensor = None, repeat_num=1, **kwargs):
9493
for row in input_ids:
9594
row_ids = process_row(row, remove_value=self.tokenizer.pad_token_id, remove_side="left").tolist()
9695
input_ids_list.append(row_ids)
97-
# NOTE(gongenlei): Output that is truncated for being too long does not return an eos_token
98-
os.environ["INFERENCE_TRUNCATED_RETURN_EOS"] = "0"
96+
9997
if self.config.dynamic_insert:
10098
outputs = self.predict_dy_insert(
10199
input_ids=input_ids_list,

0 commit comments

Comments
 (0)