@@ -20,16 +20,16 @@ __global__ inline void min_length_logits_process(T *logits,
                                                  const int64_t *min_len,
                                                  const int64_t *eos_token_id,
                                                  const int64_t bs,
-                                                 const int64_t length,
-                                                 const int64_t end_length) {
+                                                 const int64_t vocab_size,
+                                                 const int64_t eos_len) {
  int bi = threadIdx.x;
  if (bi >= bs) return;
  if (cur_len[bi] < 0) {
    return;
  }
  if (cur_len[bi] < min_len[bi]) {
-    for (int i = 0; i < end_length; i++) {
-      logits[bi * length + eos_token_id[i]] = -1e10;
+    for (int i = 0; i < eos_len; i++) {
+      logits[bi * vocab_size + eos_token_id[i]] = -1e10;
    }
  }
}
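While cur_len[bi] is still below min_len[bi], the kernel above pushes the EOS logits to a large negative value so end-of-sequence cannot be sampled. The <half> specialization updated in the next hunk uses -1e4 rather than -1e10, which keeps the mask value representable in fp16 (the largest finite half value is 65504, so -1e10 would overflow). A minimal illustrative sketch of that masking step, with names of our own choosing rather than from the patch:

// Illustrative sketch only: EOS masking for one half-precision logits row.
#include <cuda_fp16.h>
__device__ void mask_eos_fp16(half *logits_row,
                              const int64_t *eos_token_id,
                              int64_t eos_len) {
  for (int64_t i = 0; i < eos_len; ++i) {
    // -1e4 is finite in fp16; -1e10f would overflow to -inf on conversion.
    logits_row[eos_token_id[i]] = __float2half(-1e4f);
  }
}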
@@ -41,61 +41,83 @@ __global__ inline void min_length_logits_process<half>(
    const int64_t *min_len,
    const int64_t *eos_token_id,
    const int64_t bs,
-    const int64_t length,
-    const int64_t end_length) {
+    const int64_t vocab_size,
+    const int64_t eos_len) {
  int bi = threadIdx.x;
  if (bi >= bs) return;
  if (cur_len[bi] < 0) {
    return;
  }
  if (cur_len[bi] < min_len[bi]) {
-    for (int i = 0; i < end_length; i++) {
-      logits[bi * length + eos_token_id[i]] = -1e4;
+    for (int i = 0; i < eos_len; i++) {
+      logits[bi * vocab_size + eos_token_id[i]] = -1e4;
    }
  }
}

__global__ void update_repeat_times(const int64_t *pre_ids,
+                                    const int64_t *prompt_ids,
+                                    const int64_t *prompt_len,
                                    const int64_t *cur_len,
                                    int *repeat_times,
+                                    int *is_repeated,
                                    const int64_t bs,
-                                   const int64_t length,
-                                   const int64_t length_id) {
-  int bi = blockIdx.x;
+                                   const int64_t vocab_size,
+                                   const int64_t max_dec_len,
+                                   const int64_t max_model_len) {
+  int64_t bi = blockIdx.x;
  if (cur_len[bi] < 0) {
    return;
  }
-  int tid = threadIdx.x;
-  const int64_t *pre_ids_now = pre_ids + bi * length_id;
-  int *repeat_times_now = repeat_times + bi * length;
-  for (int i = tid; i < length_id; i += blockDim.x) {
-    int64_t id = pre_ids_now[i];
-    if (id < 0) break;
-    atomicAdd(&repeat_times_now[id], 1);
+  const int64_t prompt_len_now = prompt_len[bi];
+  int64_t tid = threadIdx.x;
+  const int64_t *prompt_now = prompt_ids + bi * max_model_len;
+  const int64_t *pre_ids_now = pre_ids + bi * max_dec_len;
+  int *repeat_times_now = repeat_times + bi * vocab_size;
+  int *is_repeated_now = is_repeated + bi * vocab_size;
+  const int64_t loop_len = prompt_len_now > max_dec_len ? prompt_len_now : max_dec_len;
+  for (int64_t i = tid; i < loop_len; i += blockDim.x) {
+    if (i < max_dec_len) {
+      int64_t id = pre_ids_now[i];
+      if (id >= 0) {
+        atomicAdd(&repeat_times_now[id], 1);
+        atomicAdd(&is_repeated_now[id], 1);
+      }
+    }
+    if (i < prompt_len_now) {
+      int64_t id = prompt_now[i];
+      if (id >= 0) {
+        atomicAdd(&is_repeated_now[id], 1);
+      }
+    }
  }
}

template <typename T>
__global__ void update_value_by_repeat_times(const int *repeat_times,
+                                             const int *is_repeated,
                                             const T *penalty_scores,
                                             const T *frequency_score,
                                             const T *presence_score,
                                             const float *temperatures,
                                             T *logits,
                                             const int64_t bs,
-                                            const int64_t length) {
+                                            const int64_t vocab_size) {
  int bi = blockIdx.x;
  int tid = threadIdx.x;
-  T *logits_now = logits + bi * length;
-  const int *repeat_times_now = repeat_times + bi * length;
+  T *logits_now = logits + bi * vocab_size;
+  const int *repeat_times_now = repeat_times + bi * vocab_size;
+  const int *is_repeated_now = is_repeated + bi * vocab_size;
  float alpha = static_cast<float>(penalty_scores[bi]);
  float beta = static_cast<float>(frequency_score[bi]);
  float gamma = static_cast<float>(presence_score[bi]);
-  for (int i = tid; i < length; i += blockDim.x) {
+  for (int i = tid; i < vocab_size; i += blockDim.x) {
    int times = repeat_times_now[i];
    float logit_now = static_cast<float>(logits_now[i]);
-    if (times != 0) {
+    if (is_repeated_now[i] != 0) {
      logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
+    }
+    if (times != 0) {
      logit_now = logit_now - times * beta - gamma;
    }
    logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
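Taken together, the two kernels above now distinguish tokens seen anywhere (prompt or previously decoded output, tracked in is_repeated) from tokens repeated in the decoded output alone (tracked in repeat_times): the repetition penalty alpha applies to the former, while the frequency and presence terms apply only to the latter. A reference-only restatement of the per-logit update; the function name and scalar parameters below are illustrative, not part of the patch:

// Reference sketch (not patch code): combined penalty for one vocab entry.
float apply_penalties_ref(float logit,
                          int times_in_output,   // repeat_times[i]
                          bool seen_anywhere,    // is_repeated[i] != 0
                          float alpha, float beta, float gamma,
                          float temperature) {
  if (seen_anywhere) {
    // Repetition penalty: shrink positive logits, amplify negative ones.
    logit = logit < 0 ? logit * alpha : logit / alpha;
  }
  if (times_in_output != 0) {
    // Frequency and presence penalties count decoded repetitions only.
    logit = logit - times_in_output * beta - gamma;
  }
  return logit / temperature;  // temperature scaling applied last
}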
@@ -106,20 +128,22 @@ template <typename T>
__global__ void ban_bad_words(T *logits,
                              const int64_t *bad_words_list,
                              const int64_t bs,
-                             const int64_t length,
-                             const int64_t bad_words_length) {
+                             const int64_t vocab_size,
+                             const int64_t bad_words_len) {
  const int bi = blockIdx.x;
  int tid = threadIdx.x;
-  T *logits_now = logits + bi * length;
-  for (int i = tid; i < bad_words_length; i += blockDim.x) {
+  T *logits_now = logits + bi * vocab_size;
+  for (int i = tid; i < bad_words_len; i += blockDim.x) {
    const int64_t bad_words_token_id = bad_words_list[i];
-    if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
+    if (bad_words_token_id >= vocab_size || bad_words_token_id < 0) continue;
    logits_now[bad_words_token_id] = -1e10;
  }
}

template <paddle::DataType D>
void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
+                                       const paddle::Tensor &prompt_ids,
+                                       const paddle::Tensor &prompt_len,
                                       const paddle::Tensor &logits,
                                       const paddle::Tensor &penalty_scores,
                                       const paddle::Tensor &frequency_score,
@@ -141,12 +165,15 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
  std::vector<int64_t> shape = logits.shape();
  auto repeat_times =
      paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
+  auto is_repeated =
+      paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
  int64_t bs = shape[0];
-  int64_t length = shape[1];
-  int64_t length_id = pre_ids.shape()[1];
-  int64_t length_bad_words = bad_tokens.shape()[0];

-  int64_t end_length = eos_token_id.shape()[0];
+  int64_t vocab_size = shape[1];
+  int64_t max_dec_len = pre_ids.shape()[1];
+  int64_t bad_words_len = bad_tokens.shape()[0];
+  int64_t eos_len = eos_token_id.shape()[0];
+  int64_t max_model_len = prompt_ids.shape()[1];

  int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
  min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
@@ -156,31 +183,36 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
      min_len.data<int64_t>(),
      eos_token_id.data<int64_t>(),
      bs,
-      length,
-      end_length);
+      vocab_size,
+      eos_len);

-  block_size = (length_id + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
+  block_size = (max_dec_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
  block_size = std::min(block_size, 512);
#else
  block_size = min(block_size, 512);
#endif
  update_repeat_times<<<bs, block_size, 0, cu_stream>>>(
      pre_ids.data<int64_t>(),
+      prompt_ids.data<int64_t>(),
+      prompt_len.data<int64_t>(),
      cur_len.data<int64_t>(),
      repeat_times.data<int>(),
+      is_repeated.data<int>(),
      bs,
-      length,
-      length_id);
+      vocab_size,
+      max_dec_len,
+      max_model_len);

-  block_size = (length + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
+  block_size = (vocab_size + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
  block_size = std::min(block_size, 512);
#else
  block_size = min(block_size, 512);
#endif
  update_value_by_repeat_times<DataType_><<<bs, block_size, 0, cu_stream>>>(
      repeat_times.data<int>(),
+      is_repeated.data<int>(),
      reinterpret_cast<DataType_ *>(
          const_cast<data_t *>(penalty_scores.data<data_t>())),
      reinterpret_cast<DataType_ *>(
@@ -191,9 +223,9 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
      reinterpret_cast<DataType_ *>(
          const_cast<data_t *>(logits.data<data_t>())),
      bs,
-      length);
+      vocab_size);

-  block_size = (length_bad_words + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
+  block_size = (bad_words_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
  block_size = std::min(block_size, 512);
#else
@@ -204,11 +236,13 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
          const_cast<data_t *>(logits.data<data_t>())),
      bad_tokens.data<int64_t>(),
      bs,
-      length,
-      length_bad_words);
+      vocab_size,
+      bad_words_len);
}

void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
+                             const paddle::Tensor &prompt_ids,
+                             const paddle::Tensor &prompt_len,
                             const paddle::Tensor &logits,
                             const paddle::Tensor &penalty_scores,
                             const paddle::Tensor &frequency_scores,
@@ -222,6 +256,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
    case paddle::DataType::BFLOAT16: {
      return token_penalty_multi_scores_kernel<
          paddle::DataType::BFLOAT16>(pre_ids,
+                                      prompt_ids,
+                                      prompt_len,
                                      logits,
                                      penalty_scores,
                                      frequency_scores,
@@ -233,30 +269,34 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
                                      eos_token_id);
    }
    case paddle::DataType::FLOAT16: {
-      return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
-          pre_ids,
-          logits,
-          penalty_scores,
-          frequency_scores,
-          presence_scores,
-          temperatures,
-          bad_tokens,
-          cur_len,
-          min_len,
-          eos_token_id);
+      return token_penalty_multi_scores_kernel<
+          paddle::DataType::FLOAT16>(pre_ids,
+                                     prompt_ids,
+                                     prompt_len,
+                                     logits,
+                                     penalty_scores,
+                                     frequency_scores,
+                                     presence_scores,
+                                     temperatures,
+                                     bad_tokens,
+                                     cur_len,
+                                     min_len,
+                                     eos_token_id);
    }
    case paddle::DataType::FLOAT32: {
-      return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
-          pre_ids,
-          logits,
-          penalty_scores,
-          frequency_scores,
-          presence_scores,
-          temperatures,
-          bad_tokens,
-          cur_len,
-          min_len,
-          eos_token_id);
+      return token_penalty_multi_scores_kernel<
+          paddle::DataType::FLOAT32>(pre_ids,
+                                     prompt_ids,
+                                     prompt_len,
+                                     logits,
+                                     penalty_scores,
+                                     frequency_scores,
+                                     presence_scores,
+                                     temperatures,
+                                     bad_tokens,
+                                     cur_len,
+                                     min_len,
+                                     eos_token_id);
    }
    default: {
      PD_THROW(
@@ -269,6 +309,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,

PD_BUILD_STATIC_OP(get_token_penalty_multi_scores)
    .Inputs({"pre_ids",
+             "prompt_ids",
+             "prompt_len",
             "logits",
             "penalty_scores",
             "frequency_scores",