Commit 7ec2a1c

Aya-ZIbra authored and facebook-github-bot committed
Padded KV Partial Prefill Case (#4848)
Summary:
X-link: facebookresearch/FBGEMM#1879
Pull Request resolved: #4848

The current KV padding only supported the full prefill case (D78967317). This diff adds partial prefill support as well. Coverage added in the tests.

WIP: upstreaming this (D78967317 and this diff).

Reviewed By: sryap

Differential Revision: D82080682

fbshipit-source-id: 7a6c7a0d3c32245e5c13864b1f0cfe37d8d254c4
1 parent 23f944c · commit 7ec2a1c

File tree

3 files changed (+46, -15 lines)

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_fusion.hpp

Lines changed: 23 additions & 0 deletions
@@ -636,6 +636,29 @@ apply_variable_length_offset(Shape const& shape, Coord const& coord) {
   return cute::make_tuple(result_shape, result_offset);
 }
 
+template <class Shape, class Idx>
+CUTE_HOST_DEVICE constexpr auto apply_variable_length_paddedkv(
+    Shape const& shape,
+    Idx const& idx,
+    int kv_length) {
+  // Use a position counter to track which element we're processing
+  int position_counter = 0;
+
+  return transform_leaf(shape, [&](auto const& s) {
+    if constexpr (is_variable_length_v<decltype(s)>) {
+      int current_pos = position_counter++;
+      if (current_pos == 1) {
+        return kv_length;
+      } else {
+        return s.cumulative_length[idx + 1] - s.cumulative_length[idx];
+      }
+    } else {
+      position_counter++;
+      return s;
+    }
+  });
+}
+
 } // namespace cutlass::fmha::collective
 
 namespace cute {
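For intuition (not part of the commit): a minimal Python sketch of the position-aware rule implemented by apply_variable_length_paddedkv above. The previous call site replaced every variable-length extent with seqlen_kv, which is only correct for full prefill where the Q and KV lengths coincide; the helper instead overwrites only the second variable-length leaf (K/V) and keeps Q's jagged length. The VariableLength class and the example shape layout (Q, KV, head_dim, batch) below are illustrative assumptions, not FBGEMM/CUTLASS types.

# Illustrative sketch only: mirrors the position-counter logic of
# apply_variable_length_paddedkv; VariableLength and the shape layout are
# assumptions made for this example.
from dataclasses import dataclass

@dataclass
class VariableLength:
    cumulative_length: list[int]  # prefix sums over the batch, length = batch + 1

def apply_variable_length_paddedkv(shape, idx, kv_length):
    position_counter = 0
    result = []
    for s in shape:
        if isinstance(s, VariableLength):
            current_pos = position_counter
            position_counter += 1
            if current_pos == 1:
                # Second variable-length leaf (K/V): use the caller-provided length.
                result.append(kv_length)
            else:
                # Other variable-length leaves (e.g. Q): keep the jagged length for batch idx.
                result.append(s.cumulative_length[idx + 1] - s.cumulative_length[idx])
        else:
            position_counter += 1
            result.append(s)
    return tuple(result)

# Batch 1: Q stays jagged (104 - 64 = 40 tokens), K/V is forced to kv_length = 96.
shape = (VariableLength([0, 64, 104]), VariableLength([0, 128, 256]), 128, 2)
print(apply_variable_length_paddedkv(shape, idx=1, kv_length=96))  # -> (40, 96, 128, 2)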

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp

Lines changed: 3 additions & 7 deletions
@@ -255,13 +255,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
     // If seqlen_kv is provided, use it to determine the sequence length for
     // key-value pairs
     if (params.seqlen_kv != nullptr) {
-      return transform_leaf(problem_shape, [&](auto const& s) {
-        if constexpr (is_variable_length_v<decltype(s)>) {
-          return params.seqlen_kv[batch_idx];
-        } else {
-          return s;
-        }
-      });
+      // Position-aware replacement that only replaces K/V (position 1)
+      return apply_variable_length_paddedkv(
+          problem_shape, batch_idx, params.seqlen_kv[batch_idx]);
     } else {
       // Fall back to the original behavior
       return apply_variable_length(params.problem_shape, batch_idx);

fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py

Lines changed: 20 additions & 8 deletions
@@ -438,11 +438,13 @@ def test_decode(
             (
                 kv_padding,
                 batch_size,
+                q_heads,
                 causal,
                 window_size,
             )
             for kv_padding in [128, 256, 512, 1024]
             for batch_size in [2, 8]
+            for q_heads in [8, 16]
             for causal in [True, False]
             for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
         ]
@@ -451,6 +453,7 @@ def test_jagged_vs_padded_kv(
         self,
         kv_padding: int,
         batch_size: int,
+        q_heads: int,
        causal: bool,
        window_size: tuple[int, int] = (-1, -1),
    ) -> None:
@@ -465,11 +468,9 @@ def test_jagged_vs_padded_kv(
         # kv_padding = 128
         seqlen_q = kv_padding  # Maximum sequence length (padded size)
         device = torch.accelerator.current_accelerator()
-        q_heads = 1
         kv_heads = 1
         head_dim = 128
         dtype = torch.bfloat16
-        causal = False
 
         # Create tensors
         q_padded = torch.randn(
@@ -499,11 +500,14 @@ def test_jagged_vs_padded_kv(
             device=device,
         ).to(dtype)
 
-        qk_padding_mask = generate_random_padding_mask(
+        k_padding_mask = generate_random_padding_mask(
+            kv_padding, batch_size, device, mode="random", zero_lengths=False
+        )
+        q_padding_mask = generate_random_padding_mask(
             kv_padding, batch_size, device, mode="third", zero_lengths=False
         )
         # # Always have seqlen_k >= seqlen_q
-        # key_padding_mask[:, :seqlen_q] |= query_padding_mask
+        k_padding_mask[:, :seqlen_q] |= q_padding_mask
         (
             q_unpad,
             k_unpad,
@@ -524,8 +528,8 @@ def test_jagged_vs_padded_kv(
             q_padded,
             k_padded,
             v_padded,
-            qk_padding_mask,
-            qk_padding_mask,
+            q_padding_mask,
+            k_padding_mask,
         )
         # Create variable length sequences
         cu_seqlens_k_padded = torch.zeros(
@@ -546,6 +550,9 @@ def test_jagged_vs_padded_kv(
             print(f"jagged cu_seqlens_k: {cu_seqlens_k_jagged}")
             print(f"padded cu_seqlens_k: {cu_seqlens_k_padded}")
             print(f"seqlen_kv: {seqused_k}")
+            print(f"max_seqlen_q: {max_seqlen_q}")
+            print(f"max_seqlen_k: {max_seqlen_k}")
+            print(f"q_unpad: {q_unpad.shape}")
 
         # Scenario A: Jagged KV with cu_seqlens_k
         out_jagged = cutlass_blackwell_fmha_func(
@@ -554,7 +561,7 @@ def test_jagged_vs_padded_kv(
             v_unpad,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k_jagged,
-            max_seq_len_q=seqlen_q,
+            max_seq_len_q=max_seqlen_q,
             max_seq_len_k=max_seqlen_k,
             causal=causal,
             window_size=window_size,
@@ -571,12 +578,17 @@ def test_jagged_vs_padded_kv(
             v_,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k_padded,
-            max_seq_len_q=seqlen_q,
+            max_seq_len_q=max_seqlen_q,
             max_seq_len_k=max_seqlen_k,
             causal=causal,
             window_size=window_size,
             seqlen_kv=seqused_k,
         )
+        if DEBUG:
+            print(f"out_jagged: {out_jagged}")
+            print(f"k_: {k_.shape}")
+            print(f"v_: {v_.shape}")
+            print(f"out_padded: {out_padded}")
 
         # # Compare outputs
         diff = (out_jagged - out_padded).abs().max().item()
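As a companion to the test above (a sketch, not the test's own helpers): how the padded-KV metadata fits together, assuming a boolean key padding mask like the one produced by generate_random_padding_mask. The names seqused_k, cu_seqlens_k_padded, cu_seqlens_k_jagged, and kv_padding mirror the test; the per-batch lengths used here are made up for illustration.

import torch

batch_size, kv_padding = 2, 128
k_padding_mask = torch.zeros(batch_size, kv_padding, dtype=torch.bool)
k_padding_mask[0, :50] = True   # batch 0: 50 valid K/V tokens
k_padding_mask[1, :96] = True   # batch 1: 96 valid K/V tokens

# Per-batch valid K/V lengths, passed to the kernel as seqlen_kv (seqused_k).
seqused_k = k_padding_mask.sum(dim=1, dtype=torch.int32)               # tensor([50, 96])

# Offsets into the padded K/V buffer: every batch owns a full kv_padding slot.
cu_seqlens_k_padded = torch.arange(batch_size + 1, dtype=torch.int32) * kv_padding
# tensor([  0, 128, 256])

# Offsets for the jagged layout: only the valid tokens are packed back to back.
cu_seqlens_k_jagged = torch.zeros(batch_size + 1, dtype=torch.int32)
cu_seqlens_k_jagged[1:] = torch.cumsum(seqused_k, dim=0)               # tensor([  0,  50, 146])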
