Skip to content

Commit d2d196c

Browse files
Merge branch 'develop' into develop
2 parents ac9e6d7 + ddb10ac commit d2d196c

File tree

59 files changed

+421
-315
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

59 files changed

+421
-315
lines changed

custom_ops/gpu_ops/append_attention.cu

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
4646
const paddle::Tensor& seq_lens_encoder,
4747
const paddle::Tensor& seq_lens_decoder,
4848
const paddle::Tensor& seq_lens_this_time,
49-
const paddle::Tensor& padding_offsets,
49+
const paddle::Tensor& batch_id_per_token,
5050
const paddle::Tensor& cu_seqlens_q,
5151
const paddle::Tensor& block_tables,
5252
const paddle::Tensor& encoder_batch_ids,
@@ -165,7 +165,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
165165
seq_lens_this_time,
166166
seq_lens_decoder,
167167
seq_lens_encoder,
168-
padding_offsets,
168+
batch_id_per_token,
169169
cu_seqlens_q,
170170
block_tables,
171171
lambda_batch_ids,
@@ -202,7 +202,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
202202
seq_lens_this_time,
203203
seq_lens_encoder,
204204
seq_lens_decoder,
205-
padding_offsets,
205+
batch_id_per_token,
206206
cu_seqlens_q,
207207
block_tables,
208208
kv_batch_ids,
@@ -274,7 +274,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
274274
qkv, // [token_num, num_heads, head_dim]
275275
seq_lens_decoder,
276276
seq_lens_encoder,
277-
padding_offsets,
277+
batch_id_per_token,
278278
cu_seqlens_q,
279279
block_tables,
280280
rotary_embs,
@@ -297,7 +297,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
297297
qkv_out, // [token_num, num_heads, head_dim]
298298
seq_lens_decoder,
299299
seq_lens_encoder,
300-
padding_offsets,
300+
batch_id_per_token,
301301
cu_seqlens_q,
302302
block_tables,
303303
rotary_embs,
@@ -322,7 +322,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
322322
qkv, // [token_num, num_heads, head_dim]
323323
seq_lens_decoder,
324324
seq_lens_encoder,
325-
padding_offsets,
325+
batch_id_per_token,
326326
cu_seqlens_q,
327327
block_tables,
328328
rotary_embs,
@@ -346,7 +346,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
346346
qkv_out, // [token_num, num_heads, head_dim]
347347
seq_lens_decoder,
348348
seq_lens_encoder,
349-
padding_offsets,
349+
batch_id_per_token,
350350
cu_seqlens_q,
351351
block_tables,
352352
rotary_embs,
@@ -403,7 +403,7 @@ std::vector<paddle::Tensor> AppendAttention(
403403
const paddle::Tensor& seq_lens_encoder,
404404
const paddle::Tensor& seq_lens_decoder,
405405
const paddle::Tensor& seq_lens_this_time,
406-
const paddle::Tensor& padding_offsets,
406+
const paddle::Tensor& batch_id_per_token,
407407
const paddle::Tensor& cu_seqlens_q,
408408
const paddle::Tensor& block_tables,
409409
const paddle::Tensor& encoder_batch_ids,
@@ -473,7 +473,7 @@ std::vector<paddle::Tensor> AppendAttention(
473473
seq_lens_encoder,
474474
seq_lens_decoder,
475475
seq_lens_this_time,
476-
padding_offsets,
476+
batch_id_per_token,
477477
cu_seqlens_q,
478478
block_tables,
479479
encoder_batch_ids,
@@ -550,7 +550,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
550550
const std::vector<int64_t>& seq_lens_encoder_shape,
551551
const std::vector<int64_t>& seq_lens_decoder_shape,
552552
const std::vector<int64_t>& seq_lens_this_time_shape,
553-
const std::vector<int64_t>& padding_offsets_shape,
553+
const std::vector<int64_t>& batch_id_per_token_shape,
554554
const std::vector<int64_t>& cu_seqlens_q_shape,
555555
const std::vector<int64_t>& block_tables_shape,
556556
const std::vector<int64_t>& encoder_batch_ids_shape,
@@ -610,7 +610,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
610610
const paddle::DataType& seq_lens_encoder_dtype,
611611
const paddle::DataType& seq_lens_decoder_dtype,
612612
const paddle::DataType& seq_lens_this_time_dtype,
613-
const paddle::DataType& padding_offsets_dtype,
613+
const paddle::DataType& batch_id_per_token_dtype,
614614
const paddle::DataType& cu_seqlens_q_dtype,
615615
const paddle::DataType& block_tables_dtype,
616616
const paddle::DataType& encoder_batch_ids_dtype,
@@ -688,7 +688,7 @@ PD_BUILD_STATIC_OP(append_attention)
688688
"seq_lens_encoder",
689689
"seq_lens_decoder",
690690
"seq_lens_this_time",
691-
"padding_offsets",
691+
"batch_id_per_token",
692692
"cu_seqlens_q",
693693
"block_tables",
694694
"encoder_batch_ids",

custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ void MultiQueryAppendAttention(
773773
const paddle::Tensor &seq_lens_q,
774774
const paddle::Tensor &seq_lens_kv,
775775
const paddle::Tensor &seq_lens_encoder,
776-
const paddle::Tensor &padding_offsets,
776+
const paddle::Tensor &batch_id_per_token,
777777
const paddle::Tensor &cu_seqlens_q,
778778
const paddle::Tensor &block_table,
779779
const paddle::Tensor &batch_ids,
@@ -1007,7 +1007,8 @@ void MultiQueryAppendAttention(
10071007
seq_lens_q.data<int>(),
10081008
seq_lens_kv.data<int>(),
10091009
seq_lens_encoder.data<int>(),
1010-
padding_offsets.data<int>(),
1010+
batch_id_per_token.data<int>(),
1011+
cu_seqlens_q.data<int>(),
10111012
shift_bias ? reinterpret_cast<NV_TYPE *>(
10121013
const_cast<T *>(shift_bias.get().data<T>()))
10131014
: nullptr,
@@ -1240,7 +1241,8 @@ void MultiQueryAppendAttention(
12401241
seq_lens_q.data<int>(),
12411242
seq_lens_kv.data<int>(),
12421243
seq_lens_encoder.data<int>(),
1243-
padding_offsets.data<int>(),
1244+
batch_id_per_token.data<int>(),
1245+
cu_seqlens_q.data<int>(),
12441246
shift_bias ? reinterpret_cast<NV_TYPE *>(
12451247
const_cast<T *>(shift_bias.get().data<T>()))
12461248
: nullptr,
@@ -1287,7 +1289,7 @@ void CascadeAppendAttentionC16Kernel(
12871289
const paddle::Tensor& seq_lens_q,
12881290
const paddle::Tensor& seq_lens_kv,
12891291
const paddle::Tensor& seq_lens_encoder,
1290-
const paddle::Tensor& padding_offsets,
1292+
const paddle::Tensor& batch_id_per_token,
12911293
const paddle::Tensor& cu_seqlens_q,
12921294
const paddle::Tensor& block_table,
12931295
const paddle::Tensor& batch_ids,
@@ -1350,7 +1352,7 @@ void CascadeAppendAttentionC16Kernel(
13501352
seq_lens_q,
13511353
seq_lens_kv,
13521354
seq_lens_encoder,
1353-
padding_offsets,
1355+
batch_id_per_token,
13541356
cu_seqlens_q,
13551357
block_table,
13561358
batch_ids,

custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ void MultiQueryAppendC4Attention(
960960
const paddle::Tensor &seq_lens_q,
961961
const paddle::Tensor &seq_lens_kv,
962962
const paddle::Tensor &seq_lens_encoder,
963-
const paddle::Tensor &padding_offsets,
963+
const paddle::Tensor &batch_id_per_token,
964964
const paddle::Tensor &cu_seqlens_q,
965965
const paddle::Tensor &block_table,
966966
const paddle::Tensor &batch_ids,
@@ -1219,7 +1219,8 @@ void MultiQueryAppendC4Attention(
12191219
seq_lens_q.data<int>(),
12201220
seq_lens_kv.data<int>(),
12211221
seq_lens_encoder.data<int>(),
1222-
padding_offsets.data<int>(),
1222+
batch_id_per_token.data<int>(),
1223+
cu_seqlens_q.data<int>(),
12231224
shift_bias ? reinterpret_cast<NV_TYPE *>(
12241225
const_cast<T *>(shift_bias.get().data<T>()))
12251226
: nullptr,
@@ -1477,7 +1478,8 @@ void MultiQueryAppendC4Attention(
14771478
seq_lens_q.data<int>(),
14781479
seq_lens_kv.data<int>(),
14791480
seq_lens_encoder.data<int>(),
1480-
padding_offsets.data<int>(),
1481+
batch_id_per_token.data<int>(),
1482+
cu_seqlens_q.data<int>(),
14811483
shift_bias ? reinterpret_cast<NV_TYPE *>(
14821484
const_cast<T *>(shift_bias.get().data<T>()))
14831485
: nullptr,
@@ -1524,7 +1526,7 @@ void CascadeAppendAttentionC4Kernel(
15241526
const paddle::Tensor& seq_lens_q,
15251527
const paddle::Tensor& seq_lens_kv,
15261528
const paddle::Tensor& seq_lens_encoder,
1527-
const paddle::Tensor& padding_offsets,
1529+
const paddle::Tensor& batch_id_per_token,
15281530
const paddle::Tensor& cu_seqlens_q,
15291531
const paddle::Tensor& block_table,
15301532
const paddle::Tensor& batch_ids,
@@ -1591,7 +1593,7 @@ void CascadeAppendAttentionC4Kernel(
15911593
seq_lens_q,
15921594
seq_lens_kv,
15931595
seq_lens_encoder,
1594-
padding_offsets,
1596+
batch_id_per_token,
15951597
cu_seqlens_q,
15961598
block_table,
15971599
batch_ids,

custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,7 @@ void MultiQueryAppendC8Attention(
897897
const paddle::Tensor &seq_lens_q,
898898
const paddle::Tensor &seq_lens_kv,
899899
const paddle::Tensor &seq_lens_encoder,
900-
const paddle::Tensor &padding_offsets,
900+
const paddle::Tensor &batch_id_per_token,
901901
const paddle::Tensor &cu_seqlens_q,
902902
const paddle::Tensor &block_table,
903903
const paddle::Tensor &batch_ids,
@@ -1179,7 +1179,8 @@ void MultiQueryAppendC8Attention(
11791179
seq_lens_q.data<int>(),
11801180
seq_lens_kv.data<int>(),
11811181
seq_lens_encoder.data<int>(),
1182-
padding_offsets.data<int>(),
1182+
batch_id_per_token.data<int>(),
1183+
cu_seqlens_q.data<int>(),
11831184
shift_bias ? reinterpret_cast<NV_TYPE *>(
11841185
const_cast<T *>(shift_bias.get().data<T>()))
11851186
: nullptr,
@@ -1450,7 +1451,8 @@ void MultiQueryAppendC8Attention(
14501451
seq_lens_q.data<int>(),
14511452
seq_lens_kv.data<int>(),
14521453
seq_lens_encoder.data<int>(),
1453-
padding_offsets.data<int>(),
1454+
batch_id_per_token.data<int>(),
1455+
cu_seqlens_q.data<int>(),
14541456
shift_bias ? reinterpret_cast<NV_TYPE *>(
14551457
const_cast<T *>(shift_bias.get().data<T>()))
14561458
: nullptr,
@@ -1497,7 +1499,7 @@ void CascadeAppendAttentionC8Kernel(
14971499
const paddle::Tensor& seq_lens_q,
14981500
const paddle::Tensor& seq_lens_kv,
14991501
const paddle::Tensor& seq_lens_encoder,
1500-
const paddle::Tensor& padding_offsets,
1502+
const paddle::Tensor& batch_id_per_token,
15011503
const paddle::Tensor& cu_seqlens_q,
15021504
const paddle::Tensor& block_table,
15031505
const paddle::Tensor& batch_ids,
@@ -1562,7 +1564,7 @@ void CascadeAppendAttentionC8Kernel(
15621564
seq_lens_q,
15631565
seq_lens_kv,
15641566
seq_lens_encoder,
1565-
padding_offsets,
1567+
batch_id_per_token,
15661568
cu_seqlens_q,
15671569
block_table,
15681570
batch_ids,

custom_ops/gpu_ops/append_attn/append_attention_func.cuh

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,7 +1852,7 @@ __global__ void merge_multi_chunks_kernel(
18521852
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
18531853
const int* __restrict__ seq_lens_q,
18541854
const int* __restrict__ seq_lens_kv,
1855-
const int* __restrict__ padding_offsets,
1855+
const int* __restrict__ batch_id_per_token,
18561856
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
18571857
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
18581858
T* __restrict__ out,
@@ -1866,8 +1866,7 @@ __global__ void merge_multi_chunks_kernel(
18661866
const int head_dim) {
18671867
const int vid = threadIdx.x, hid = threadIdx.y;
18681868
const int qid = blockIdx.x;
1869-
const uint32_t ori_token_id = qid + padding_offsets[qid];
1870-
const uint32_t bid = ori_token_id / max_seq_len;
1869+
const uint32_t bid = batch_id_per_token[qid];
18711870
if (seq_lens_q[bid] <= 0 || seq_lens_kv[bid] <= 0) {
18721871
return;
18731872
}
@@ -2240,7 +2239,8 @@ __global__ void merge_multi_chunks_v2_kernel(
22402239
const int *__restrict__ seq_lens_q,
22412240
const int *__restrict__ seq_lens_kv,
22422241
const int *__restrict__ seq_lens_encoder,
2243-
const int *__restrict__ padding_offsets,
2242+
const int *__restrict__ batch_id_per_token,
2243+
const int *__restrict__ cu_seqlens_q,
22442244
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
22452245
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
22462246
OutT *__restrict__ out,
@@ -2259,9 +2259,8 @@ __global__ void merge_multi_chunks_v2_kernel(
22592259
__shared__ T smem[bdy * HEAD_DIM];
22602260
__shared__ float md_smem[bdy * 2];
22612261
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
2262-
const uint32_t ori_token_id = qid + padding_offsets[qid];
2263-
const uint32_t bid = ori_token_id / max_seq_len;
2264-
const uint32_t local_seq_id = ori_token_id % max_seq_len;
2262+
const uint32_t bid = batch_id_per_token[qid];
2263+
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
22652264
const int seq_len_q = seq_lens_q[bid];
22662265
if (seq_len_q == 0) continue;
22672266
int seq_len_kv = seq_lens_kv[bid];

custom_ops/gpu_ops/append_attn/append_attention_kernel.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void CascadeAppendAttentionC16Kernel(
4040
const paddle::Tensor& seq_lens_q,
4141
const paddle::Tensor& seq_lens_kv,
4242
const paddle::Tensor& seq_lens_encoder,
43-
const paddle::Tensor& padding_offsets,
43+
const paddle::Tensor& batch_id_per_token,
4444
const paddle::Tensor& cu_seqlens_q,
4545
const paddle::Tensor& block_table,
4646
const paddle::Tensor& batch_ids,
@@ -85,7 +85,7 @@ void CascadeAppendAttentionC8Kernel(
8585
const paddle::Tensor& seq_lens_q,
8686
const paddle::Tensor& seq_lens_kv,
8787
const paddle::Tensor& seq_lens_encoder,
88-
const paddle::Tensor& padding_offsets,
88+
const paddle::Tensor& batch_id_per_token,
8989
const paddle::Tensor& cu_seqlens_q,
9090
const paddle::Tensor& block_table,
9191
const paddle::Tensor& batch_ids,
@@ -130,7 +130,7 @@ void CascadeAppendAttentionC4Kernel(
130130
const paddle::Tensor& seq_lens_q,
131131
const paddle::Tensor& seq_lens_kv,
132132
const paddle::Tensor& seq_lens_encoder,
133-
const paddle::Tensor& padding_offsets,
133+
const paddle::Tensor& batch_id_per_token,
134134
const paddle::Tensor& cu_seqlens_q,
135135
const paddle::Tensor& block_table,
136136
const paddle::Tensor& batch_ids,
@@ -175,7 +175,7 @@ void CascadeAppendAttentionKernel(
175175
const paddle::Tensor& seq_lens_q,
176176
const paddle::Tensor& seq_lens_kv,
177177
const paddle::Tensor& seq_lens_encoder,
178-
const paddle::Tensor& padding_offsets,
178+
const paddle::Tensor& batch_id_per_token,
179179
const paddle::Tensor& cu_seqlens_q,
180180
const paddle::Tensor& block_table,
181181
const paddle::Tensor& batch_ids,
@@ -211,7 +211,7 @@ void CascadeAppendAttentionKernel(
211211
seq_lens_q,
212212
seq_lens_kv,
213213
seq_lens_encoder,
214-
padding_offsets,
214+
batch_id_per_token,
215215
cu_seqlens_q,
216216
block_table,
217217
batch_ids,
@@ -246,7 +246,7 @@ void CascadeAppendAttentionKernel(
246246
seq_lens_q,
247247
seq_lens_kv,
248248
seq_lens_encoder,
249-
padding_offsets,
249+
batch_id_per_token,
250250
cu_seqlens_q,
251251
block_table,
252252
batch_ids,
@@ -281,7 +281,7 @@ void CascadeAppendAttentionKernel(
281281
seq_lens_q,
282282
seq_lens_kv,
283283
seq_lens_encoder,
284-
padding_offsets,
284+
batch_id_per_token,
285285
cu_seqlens_q,
286286
block_table,
287287
batch_ids,
@@ -316,7 +316,7 @@ void CascadeAppendAttentionKernel(
316316
seq_lens_q,
317317
seq_lens_kv,
318318
seq_lens_encoder,
319-
padding_offsets,
319+
batch_id_per_token,
320320
cu_seqlens_q,
321321
block_table,
322322
batch_ids,

0 commit comments

Comments
 (0)