@@ -42,8 +42,10 @@ __global__ void update_inputs_kernel_v2(
     const int bsz,
     const int max_bsz,
     const int input_ids_stride,
-    const int end_length) {
+    const int end_length,
+    const int Flag_truncated_return_eos) {
   int thread_idx = threadIdx.x;
+  bool output_len_truncated = false;
   // update step_idx and stop_flags
   if (thread_idx < max_bsz) {
     bool stop_flag = stop_flags[thread_idx];
@@ -52,6 +54,7 @@ __global__ void update_inputs_kernel_v2(
     }
     if (step_idx[thread_idx] >= max_dec_len[thread_idx]) {
       stop_flags[thread_idx] = true;
+      output_len_truncated = true;
     }
   }
   __syncthreads();
@@ -60,11 +63,15 @@ __global__ void update_inputs_kernel_v2(
     if (stop_flags[thread_idx]) {
       if (seq_lens_this_time[thread_idx] == 0) {
         next_tokens[thread_idx] = -1;
+      } else {
+        if (!Flag_truncated_return_eos && output_len_truncated) {
+          // When the output length is truncated, do not return eos (for RL).
+          kwargs_next_tokens[thread_idx] = next_tokens[thread_idx];
+        } else {
+          next_tokens[thread_idx] = end_ids[0];
+          kwargs_next_tokens[thread_idx] = end_ids[0];
+        }
       }
-      // else {
-      //   next_tokens[thread_idx] = end_ids[0];
-      //   kwargs_next_tokens[thread_idx] = end_ids[0];
-      // }
     } else {
       kwargs_next_tokens[thread_idx] = next_tokens[thread_idx];
     }
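The branch added above only applies to stopped slots that still produced a token this step: if the stop came from hitting max_dec_len (so output_len_truncated is set) and Flag_truncated_return_eos is 0, the sampled token is kept and only mirrored into kwargs_next_tokens; otherwise both buffers receive end_ids[0]. Below is a minimal host-side C++ sketch of that decision; the decide_next_token helper and the token values are illustrative only and are not part of the kernel.

#include <cstdio>

// Illustrative model of the per-slot branch above: decide what lands in
// next_tokens / kwargs_next_tokens for a slot that stopped after sampling.
struct TokenPair {
  int next_token;         // what the decode loop sees next
  int kwargs_next_token;  // what is handed back to the caller (e.g. RL rollout)
};

TokenPair decide_next_token(int sampled_token, int eos_id,
                            bool output_len_truncated,
                            bool flag_truncated_return_eos) {
  if (!flag_truncated_return_eos && output_len_truncated) {
    // Truncated by output length and EOS suppressed: keep the sampled token.
    return {sampled_token, sampled_token};
  }
  // Normal stop: report EOS in both buffers.
  return {eos_id, eos_id};
}

int main() {
  const int eos_id = 2, sampled = 1234;  // hypothetical ids for illustration
  TokenPair a = decide_next_token(sampled, eos_id, /*truncated=*/true, /*return_eos=*/false);
  TokenPair b = decide_next_token(sampled, eos_id, /*truncated=*/true, /*return_eos=*/true);
  std::printf("eos suppressed: next=%d kwargs=%d\n", a.next_token, a.kwargs_next_token);
  std::printf("default:        next=%d kwargs=%d\n", b.next_token, b.kwargs_next_token);
  return 0;
}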
@@ -128,6 +135,15 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
   const int end_length = end_ids.shape()[0];

   auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
+  int Flag_truncated_return_eos = 1;
+  if (const char* inference_truncated_return_eos_env_p =
+          std::getenv("INFERENCE_TRUNCATED_RETURN_EOS")) {
+    std::string inference_truncated_return_eos_env_str(
+        inference_truncated_return_eos_env_p);
+    int inference_truncated_return_eos_from_env =
+        std::stoi(inference_truncated_return_eos_env_str);
+    Flag_truncated_return_eos = inference_truncated_return_eos_from_env;
+  }

   update_inputs_kernel_v2<1024><<<1, 1024, 0, input_ids.stream()>>>(
       const_cast<bool*>(not_need_stop_gpu.data<bool>()),
@@ -146,7 +162,8 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
       now_bsz,
       max_bsz,
       input_ids_stride,
-      end_length
+      end_length,
+      Flag_truncated_return_eos
   );

   auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false);
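On the host side the flag defaults to 1, so existing behaviour is unchanged unless the INFERENCE_TRUNCATED_RETURN_EOS environment variable is set (for example to 0 for RL rollouts). A standalone sketch of the same override pattern follows; the read_truncated_return_eos_flag helper, the main driver, and the POSIX setenv call are illustrative only, not part of the change.

#include <cstdio>
#include <cstdlib>
#include <string>

// Same pattern as the host code above: default to 1 (return EOS on
// truncation) unless INFERENCE_TRUNCATED_RETURN_EOS overrides it.
int read_truncated_return_eos_flag() {
  int flag = 1;
  if (const char* env_p = std::getenv("INFERENCE_TRUNCATED_RETURN_EOS")) {
    flag = std::stoi(std::string(env_p));
  }
  return flag;
}

int main() {
  std::printf("default:    %d\n", read_truncated_return_eos_flag());
  // POSIX-only; shown just to demonstrate the override taking effect.
  setenv("INFERENCE_TRUNCATED_RETURN_EOS", "0", /*overwrite=*/1);
  std::printf("overridden: %d\n", read_truncated_return_eos_flag());
  return 0;
}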