Skip to content

Commit 99e1c10

Browse files
authored
[infer] update_input_v2 op will not return eos when env_var is set (#10628)
* Revert "output_len truncated without eos_token_id (#10614)" This reverts commit 850c6c2. * check * check * add env var
1 parent 7dbba98 commit 99e1c10

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

csrc/gpu/update_inputs_v2.cu

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@ __global__ void update_inputs_kernel_v2(
4242
const int bsz,
4343
const int max_bsz,
4444
const int input_ids_stride,
45-
const int end_length) {
45+
const int end_length,
46+
const int Flag_truncated_return_eos) {
4647
int thread_idx = threadIdx.x;
48+
bool output_len_truncated = false;
4749
// update step_idx and stop_flags
4850
if (thread_idx < max_bsz) {
4951
bool stop_flag = stop_flags[thread_idx];
@@ -52,6 +54,7 @@ __global__ void update_inputs_kernel_v2(
5254
}
5355
if (step_idx[thread_idx] >= max_dec_len[thread_idx]) {
5456
stop_flags[thread_idx] = true;
57+
output_len_truncated = true;
5558
}
5659
}
5760
__syncthreads();
@@ -60,11 +63,15 @@ __global__ void update_inputs_kernel_v2(
6063
if (stop_flags[thread_idx]) {
6164
if (seq_lens_this_time[thread_idx] == 0) {
6265
next_tokens[thread_idx] = -1;
66+
} else {
67+
if (!Flag_truncated_return_eos && output_len_truncated) {
68+
// When the output length was truncated (step_idx reached max_dec_len), keep the current next token instead of returning EOS (for RL).
69+
kwargs_next_tokens[thread_idx] = next_tokens[thread_idx];
70+
}else{
71+
next_tokens[thread_idx] = end_ids[0];
72+
kwargs_next_tokens[thread_idx] = end_ids[0];
73+
}
6374
}
64-
// else {
65-
// next_tokens[thread_idx] = end_ids[0];
66-
// kwargs_next_tokens[thread_idx] = end_ids[0];
67-
// }
6875
} else {
6976
kwargs_next_tokens[thread_idx] = next_tokens[thread_idx];
7077
}
@@ -128,6 +135,15 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
128135
const int end_length = end_ids.shape()[0];
129136

130137
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
138+
int Flag_truncated_return_eos = 1;
139+
if (const char* inference_truncated_return_eos_env_p =
140+
std::getenv("INFERENCE_TRUNCATED_RETURN_EOS")) {
141+
std::string inference_truncated_return_eos_env_str(
142+
inference_truncated_return_eos_env_p);
143+
int inference_truncated_return_eos_from_env =
144+
std::stoi(inference_truncated_return_eos_env_str);
145+
Flag_truncated_return_eos = inference_truncated_return_eos_from_env;
146+
}
131147

132148
update_inputs_kernel_v2<1024><<<1, 1024, 0, input_ids.stream()>>>(
133149
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
@@ -146,7 +162,8 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
146162
now_bsz,
147163
max_bsz,
148164
input_ids_stride,
149-
end_length
165+
end_length,
166+
Flag_truncated_return_eos
150167
);
151168

152169
auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false);

llm/predict/predictor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,8 @@ def predict_dy_insert(
13381338
repeat_num=1,
13391339
**kwargs
13401340
):
1341+
# When the output is truncated for being too long, do not return an eos_token.
1342+
os.environ["INFERENCE_TRUNCATED_RETURN_EOS"] = "0"
13411343
assert repeat_num >= 1
13421344
flag_current_rank_run = self.tensor_parallel_rank == 0 or all_rank_return
13431345
self.input_ids = []

0 commit comments

Comments
 (0)