
Commit ac9e6d7

Merge branch 'develop' into develop
2 parents 469b1d7 + 67180c1 commit ac9e6d7

26 files changed, +637 -437 lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 1 deletion
@@ -2,7 +2,9 @@ name: CI
 
 on:
   pull_request:
-    branches: [ develop ]
+    branches:
+      - develop
+      - 'release/*'
   workflow_dispatch:
 
 concurrency:

.github/workflows/ci_xpu.yml

Lines changed: 3 additions & 1 deletion
@@ -2,7 +2,9 @@ name: CI_XPU
 
 on:
   pull_request:
-    branches: [ develop ]
+    branches:
+      - develop
+      - 'release/*'
   workflow_dispatch:
 
 concurrency:

custom_ops/gpu_ops/token_penalty_multi_scores.cu

Lines changed: 106 additions & 64 deletions
@@ -20,16 +20,16 @@ __global__ inline void min_length_logits_process(T *logits,
                                                  const int64_t *min_len,
                                                  const int64_t *eos_token_id,
                                                  const int64_t bs,
-                                                 const int64_t length,
-                                                 const int64_t end_length) {
+                                                 const int64_t vocab_size,
+                                                 const int64_t eos_len) {
   int bi = threadIdx.x;
   if (bi >= bs) return;
   if (cur_len[bi] < 0) {
     return;
   }
   if (cur_len[bi] < min_len[bi]) {
-    for (int i = 0; i < end_length; i++) {
-      logits[bi * length + eos_token_id[i]] = -1e10;
+    for (int i = 0; i < eos_len; i++) {
+      logits[bi * vocab_size + eos_token_id[i]] = -1e10;
     }
   }
 }
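
Note (not part of the diff): both instantiations keep the same behavior under the new names. While a row has generated fewer than min_len tokens, every EOS logit is pushed far negative so EOS cannot be sampled; the half specialization in the next hunk uses -1e4 because 1e10 is outside fp16 range. An illustrative CPU reference, assuming nothing beyond the kernel shown above:

#include <cstdint>
#include <vector>

// CPU sketch of min_length_logits_process: for each active row that has not
// yet reached its minimum length, mask all EOS logits so EOS cannot win.
void min_length_mask_ref(std::vector<float> &logits,               // [bs, vocab_size]
                         const std::vector<int64_t> &cur_len,      // [bs]
                         const std::vector<int64_t> &min_len,      // [bs]
                         const std::vector<int64_t> &eos_token_id, // [eos_len]
                         int64_t bs, int64_t vocab_size) {
  for (int64_t bi = 0; bi < bs; ++bi) {
    if (cur_len[bi] < 0) continue;            // inactive row, skip
    if (cur_len[bi] >= min_len[bi]) continue; // minimum length already met
    for (int64_t eos : eos_token_id) {
      logits[bi * vocab_size + eos] = -1e10f;
    }
  }
}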
@@ -41,61 +41,83 @@ __global__ inline void min_length_logits_process<half>(
     const int64_t *min_len,
     const int64_t *eos_token_id,
     const int64_t bs,
-    const int64_t length,
-    const int64_t end_length) {
+    const int64_t vocab_size,
+    const int64_t eos_len) {
   int bi = threadIdx.x;
   if (bi >= bs) return;
   if (cur_len[bi] < 0) {
     return;
   }
   if (cur_len[bi] < min_len[bi]) {
-    for (int i = 0; i < end_length; i++) {
-      logits[bi * length + eos_token_id[i]] = -1e4;
+    for (int i = 0; i < eos_len; i++) {
+      logits[bi * vocab_size + eos_token_id[i]] = -1e4;
     }
   }
 }
 
 __global__ void update_repeat_times(const int64_t *pre_ids,
+                                    const int64_t *prompt_ids,
+                                    const int64_t *prompt_len,
                                     const int64_t *cur_len,
                                     int *repeat_times,
+                                    int *is_repeated,
                                     const int64_t bs,
-                                    const int64_t length,
-                                    const int64_t length_id) {
-  int bi = blockIdx.x;
+                                    const int64_t vocab_size,
+                                    const int64_t max_dec_len,
+                                    const int64_t max_model_len) {
+  int64_t bi = blockIdx.x;
   if (cur_len[bi] < 0) {
     return;
   }
-  int tid = threadIdx.x;
-  const int64_t *pre_ids_now = pre_ids + bi * length_id;
-  int *repeat_times_now = repeat_times + bi * length;
-  for (int i = tid; i < length_id; i += blockDim.x) {
-    int64_t id = pre_ids_now[i];
-    if (id < 0) break;
-    atomicAdd(&repeat_times_now[id], 1);
+  const int64_t prompt_len_now = prompt_len[bi];
+  int64_t tid = threadIdx.x;
+  const int64_t *prompt_now = prompt_ids + bi * max_model_len;
+  const int64_t *pre_ids_now = pre_ids + bi * max_dec_len;
+  int *repeat_times_now = repeat_times + bi * vocab_size;
+  int *is_repeated_now = is_repeated + bi * vocab_size;
+  const int64_t loop_len = prompt_len_now > max_dec_len ? prompt_len_now : max_dec_len;
+  for (int64_t i = tid; i < loop_len; i += blockDim.x) {
+    if (i < max_dec_len) {
+      int64_t id = pre_ids_now[i];
+      if (id >= 0) {
+        atomicAdd(&repeat_times_now[id], 1);
+        atomicAdd(&is_repeated_now[id], 1);
+      }
+    }
+    if (i < prompt_len_now) {
+      int64_t id = prompt_now[i];
+      if (id >= 0) {
+        atomicAdd(&is_repeated_now[id], 1);
+      }
+    }
   }
 }
 
 template <typename T>
 __global__ void update_value_by_repeat_times(const int *repeat_times,
+                                             const int *is_repeated,
                                              const T *penalty_scores,
                                              const T *frequency_score,
                                              const T *presence_score,
                                              const float *temperatures,
                                              T *logits,
                                              const int64_t bs,
-                                             const int64_t length) {
+                                             const int64_t vocab_size) {
   int bi = blockIdx.x;
   int tid = threadIdx.x;
-  T *logits_now = logits + bi * length;
-  const int *repeat_times_now = repeat_times + bi * length;
+  T *logits_now = logits + bi * vocab_size;
+  const int *repeat_times_now = repeat_times + bi * vocab_size;
+  const int *is_repeated_now = is_repeated + bi * vocab_size;
   float alpha = static_cast<float>(penalty_scores[bi]);
   float beta = static_cast<float>(frequency_score[bi]);
   float gamma = static_cast<float>(presence_score[bi]);
-  for (int i = tid; i < length; i += blockDim.x) {
+  for (int i = tid; i < vocab_size; i += blockDim.x) {
     int times = repeat_times_now[i];
     float logit_now = static_cast<float>(logits_now[i]);
-    if (times != 0) {
+    if (is_repeated_now[i] != 0) {
       logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
+    }
+    if (times != 0) {
       logit_now = logit_now - times * beta - gamma;
     }
     logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
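
Note (not part of the diff): the penalty step now uses two signals. repeat_times counts how often a token was generated during decoding, while is_repeated also flags tokens that appear in the prompt. The repetition penalty (alpha) therefore fires for any token seen in the prompt or the output, whereas the frequency and presence penalties (beta, gamma) still depend only on the decoded count, with temperature applied last. An illustrative per-logit CPU reference of the same math:

// CPU sketch of the per-logit update in update_value_by_repeat_times.
// `times` = decoded-only occurrence count, `seen` = nonzero if the token
// appeared in the prompt or the decoded output. Illustrative only.
float apply_penalties_ref(float logit, int times, int seen,
                          float alpha,        // repetition penalty
                          float beta,         // frequency penalty
                          float gamma,        // presence penalty
                          float temperature) {
  if (seen != 0) {
    // Repetition penalty: shrink positive logits, amplify negative ones.
    logit = logit < 0 ? logit * alpha : logit / alpha;
  }
  if (times != 0) {
    // Frequency/presence penalties scale with the decoded count only.
    logit = logit - times * beta - gamma;
  }
  return logit / temperature;
}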
@@ -106,20 +128,22 @@ template <typename T>
 __global__ void ban_bad_words(T *logits,
                               const int64_t *bad_words_list,
                               const int64_t bs,
-                              const int64_t length,
-                              const int64_t bad_words_length) {
+                              const int64_t vocab_size,
+                              const int64_t bad_words_len) {
   const int bi = blockIdx.x;
   int tid = threadIdx.x;
-  T *logits_now = logits + bi * length;
-  for (int i = tid; i < bad_words_length; i += blockDim.x) {
+  T *logits_now = logits + bi * vocab_size;
+  for (int i = tid; i < bad_words_len; i += blockDim.x) {
     const int64_t bad_words_token_id = bad_words_list[i];
-    if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
+    if (bad_words_token_id >= vocab_size || bad_words_token_id < 0) continue;
     logits_now[bad_words_token_id] = -1e10;
   }
 }
 
 template <paddle::DataType D>
 void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
+                                       const paddle::Tensor &prompt_ids,
+                                       const paddle::Tensor &prompt_len,
                                        const paddle::Tensor &logits,
                                        const paddle::Tensor &penalty_scores,
                                        const paddle::Tensor &frequency_score,
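
Note (not part of the diff): the ban_bad_words change is a pure rename (length to vocab_size, bad_words_length to bad_words_len); behavior is unchanged. For reference, an illustrative CPU sketch of that behavior:

#include <cstdint>
#include <vector>

// CPU sketch of ban_bad_words: every banned token id inside [0, vocab_size)
// has its logit forced to -1e10 for every row; out-of-range ids are skipped.
void ban_bad_words_ref(std::vector<float> &logits,                  // [bs, vocab_size]
                       const std::vector<int64_t> &bad_words_list,  // [bad_words_len]
                       int64_t bs, int64_t vocab_size) {
  for (int64_t bi = 0; bi < bs; ++bi) {
    for (int64_t id : bad_words_list) {
      if (id < 0 || id >= vocab_size) continue;  // guard against invalid ids
      logits[bi * vocab_size + id] = -1e10f;
    }
  }
}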
@@ -141,12 +165,15 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
   std::vector<int64_t> shape = logits.shape();
   auto repeat_times =
       paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
+  auto is_repeated =
+      paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
   int64_t bs = shape[0];
-  int64_t length = shape[1];
-  int64_t length_id = pre_ids.shape()[1];
-  int64_t length_bad_words = bad_tokens.shape()[0];
 
-  int64_t end_length = eos_token_id.shape()[0];
+  int64_t vocab_size = shape[1];
+  int64_t max_dec_len = pre_ids.shape()[1];
+  int64_t bad_words_len = bad_tokens.shape()[0];
+  int64_t eos_len = eos_token_id.shape()[0];
+  int64_t max_model_len = prompt_ids.shape()[1];
 
   int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
   min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
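
Note (not part of the diff): the launcher keeps the same sizing rule throughout, rounding the work size up to a whole number of warps and capping it at 512 threads per block. An illustrative helper, assuming WARP_SIZE is 32 (the real value comes from the build's headers):

#include <algorithm>
#include <cstdint>

constexpr int kWarpSize = 32;  // assumption for this sketch

// Mirror of the block_size computation: round n up to a multiple of the warp
// size, then cap at 512 threads per block.
int rounded_block_size(int64_t n) {
  int block = static_cast<int>((n + kWarpSize - 1) / kWarpSize * kWarpSize);
  return std::min(block, 512);
}
// e.g. rounded_block_size(100) == 128, rounded_block_size(4096) == 512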
@@ -156,31 +183,36 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
       min_len.data<int64_t>(),
       eos_token_id.data<int64_t>(),
       bs,
-      length,
-      end_length);
+      vocab_size,
+      eos_len);
 
-  block_size = (length_id + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
+  block_size = (max_dec_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
 #ifdef PADDLE_WITH_COREX
   block_size = std::min(block_size, 512);
 #else
   block_size = min(block_size, 512);
 #endif
   update_repeat_times<<<bs, block_size, 0, cu_stream>>>(
       pre_ids.data<int64_t>(),
+      prompt_ids.data<int64_t>(),
+      prompt_len.data<int64_t>(),
       cur_len.data<int64_t>(),
       repeat_times.data<int>(),
+      is_repeated.data<int>(),
       bs,
-      length,
-      length_id);
+      vocab_size,
+      max_dec_len,
+      max_model_len);
 
-  block_size = (length + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
+  block_size = (vocab_size + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
 #ifdef PADDLE_WITH_COREX
   block_size = std::min(block_size, 512);
 #else
   block_size = min(block_size, 512);
 #endif
   update_value_by_repeat_times<DataType_><<<bs, block_size, 0, cu_stream>>>(
       repeat_times.data<int>(),
+      is_repeated.data<int>(),
       reinterpret_cast<DataType_ *>(
           const_cast<data_t *>(penalty_scores.data<data_t>())),
       reinterpret_cast<DataType_ *>(
@@ -191,9 +223,9 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
       reinterpret_cast<DataType_ *>(
           const_cast<data_t *>(logits.data<data_t>())),
       bs,
-      length);
+      vocab_size);
 
-  block_size = (length_bad_words + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
+  block_size = (bad_words_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
 #ifdef PADDLE_WITH_COREX
   block_size = std::min(block_size, 512);
 #else
@@ -204,11 +236,13 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
           const_cast<data_t *>(logits.data<data_t>())),
       bad_tokens.data<int64_t>(),
       bs,
-      length,
-      length_bad_words);
+      vocab_size,
+      bad_words_len);
 }
 
 void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
+                             const paddle::Tensor &prompt_ids,
+                             const paddle::Tensor &prompt_len,
                              const paddle::Tensor &logits,
                              const paddle::Tensor &penalty_scores,
                              const paddle::Tensor &frequency_scores,
@@ -222,6 +256,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
     case paddle::DataType::BFLOAT16: {
       return token_penalty_multi_scores_kernel<
           paddle::DataType::BFLOAT16>(pre_ids,
+                                      prompt_ids,
+                                      prompt_len,
                                       logits,
                                       penalty_scores,
                                       frequency_scores,
@@ -233,30 +269,34 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
                                       eos_token_id);
     }
     case paddle::DataType::FLOAT16: {
-      return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
-          pre_ids,
-          logits,
-          penalty_scores,
-          frequency_scores,
-          presence_scores,
-          temperatures,
-          bad_tokens,
-          cur_len,
-          min_len,
-          eos_token_id);
+      return token_penalty_multi_scores_kernel<
+          paddle::DataType::FLOAT16>(pre_ids,
+                                     prompt_ids,
+                                     prompt_len,
+                                     logits,
+                                     penalty_scores,
+                                     frequency_scores,
+                                     presence_scores,
+                                     temperatures,
+                                     bad_tokens,
+                                     cur_len,
+                                     min_len,
+                                     eos_token_id);
     }
     case paddle::DataType::FLOAT32: {
-      return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
-          pre_ids,
-          logits,
-          penalty_scores,
-          frequency_scores,
-          presence_scores,
-          temperatures,
-          bad_tokens,
-          cur_len,
-          min_len,
-          eos_token_id);
+      return token_penalty_multi_scores_kernel<
+          paddle::DataType::FLOAT32>(pre_ids,
+                                     prompt_ids,
+                                     prompt_len,
+                                     logits,
+                                     penalty_scores,
+                                     frequency_scores,
+                                     presence_scores,
+                                     temperatures,
+                                     bad_tokens,
+                                     cur_len,
+                                     min_len,
+                                     eos_token_id);
     }
     default: {
       PD_THROW(
@@ -269,6 +309,8 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
 
 PD_BUILD_STATIC_OP(get_token_penalty_multi_scores)
     .Inputs({"pre_ids",
+             "prompt_ids",
+             "prompt_len",
             "logits",
             "penalty_scores",
            "frequency_scores",
