
Commit ffad92c

gchalump authored and facebook-github-bot committed
Add head_dim = 64 in B200 Attention.
Summary:
X-link: facebookresearch/FBGEMM#1942

This diff adds support for head_dim = 64 in the B200 Attention module.

Reviewed By: sryap

Differential Revision: D82996471
1 parent c1f22a9 commit ffad92c

3 files changed: 57 additions & 21 deletions

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_bwd.cu

Lines changed: 28 additions & 13 deletions
@@ -5,6 +5,7 @@
 template <
     typename Element,
     typename ActiveMask,
+    int HeadDim,
     bool kIsVarlen,
     bool kIsDeterministic,
     class... KernelOptions>
@@ -34,8 +35,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
       cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<cute::tuple<int, int>, int>>,
       cute::tuple<int, int, int, int, cute::tuple<cute::tuple<int, int>, int>>
       >;
-
-  using TileShape = Shape<_128, _128, _128>;
+  using D_H = cute::Int<HeadDim>;
+  using TileShape = Shape<_128, _128, D_H>;
 
   using Operation = cutlass::fmha::device::
       Sm100FmhaBwd<ProblemShapeType, Element, ElementAccumulator, TileShape, /*kIsMla=*/false, ActiveMask, kIsDeterministic>;
@@ -114,7 +115,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
     problem_shape = cute::make_tuple(
         Q, K, D, D, make_shape(make_shape(H_R, H_K), B));
   }
-
+  TORCH_CHECK(D == HeadDim);
   TORCH_CHECK(D % 8 == 0); // Alignment
   if constexpr (!kIsVarlen) {
     // TODO: support Q < 8
@@ -314,13 +315,15 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
      [&](
          auto element,
          auto element_out,
+          auto head_dim,
          auto varlen,
          auto deterministic,
          auto mask,
          auto... kernel_options) {
        return fmha_bwd<
            decltype(element),
            decltype(mask),
+            head_dim,
            varlen,
            deterministic,
            decltype(kernel_options)...>
@@ -340,51 +343,63 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
        window_size_right);
  };
 
-  auto dispatch_type = [&](auto varlen, auto deterministic, auto mask) {
+  auto dispatch_type = [&](auto varlen, auto deterministic, auto mask, auto head_dim) {
    if (query.dtype() == torch::kFloat16) {
      return dispatch_fmha(
-          cutlass::half_t{}, cutlass::half_t{}, varlen, deterministic, mask);
+          cutlass::half_t{}, cutlass::half_t{}, head_dim, varlen, deterministic, mask);
    }
    else if (query.dtype() == torch::kBFloat16) {
      return dispatch_fmha(
-          cutlass::bfloat16_t{}, cutlass::bfloat16_t{}, varlen, deterministic, mask);
+          cutlass::bfloat16_t{}, cutlass::bfloat16_t{}, head_dim, varlen, deterministic, mask);
    }
    else if (query.dtype() == torch::kFloat8_e4m3fn) {
      return dispatch_fmha(
-          cutlass::float_e4m3_t{}, cutlass::bfloat16_t{}, varlen, deterministic, mask);
+          cutlass::float_e4m3_t{}, cutlass::bfloat16_t{}, head_dim, varlen, deterministic, mask);
    }
    TORCH_CHECK(false, "Unsupported dtype for q: ", query.dtype());
  };
 
+  auto dispatch_head_dim = [&](auto varlen, auto deterministic, auto mask) {
+    if (query.size(query.dim() - 1) == 128) {
+      return dispatch_type(varlen, deterministic, mask, std::integral_constant<int, 128>{});
+    }
+    else if (query.size(query.dim() - 1) == 64) {
+      return dispatch_type(varlen, deterministic, mask, std::integral_constant<int, 64>{});
+    }
+    else {
+      TORCH_CHECK(false, "Unsupported head dim: ", query.size(query.dim() - 1));
+    }
+  };
+
  auto dispatch_mask = [&](auto varlen, auto deterministic) {
    if (causal) {
      if (bottom_right) {
-        return dispatch_type(
+        return dispatch_head_dim(
            varlen, deterministic, CausalForBackwardMask</*kIsQBegin=*/false>{});
      }
      else {
-        return dispatch_type(
+        return dispatch_head_dim(
            varlen, deterministic, CausalForBackwardMask</*kIsQBegin=*/true>{});
      }
    }
    else if (local) {
      if (bottom_right) {
-        return dispatch_type(
+        return dispatch_head_dim(
            varlen, deterministic, LocalMaskForBackward</*kIsQBegin=*/false>{});
      }
      else {
-        return dispatch_type(
+        return dispatch_head_dim(
            varlen, deterministic, LocalMaskForBackward</*kIsQBegin=*/true>{});
      }
    }
    else if (varlen || key.size(1) % 128 != 0) {
      // Use the residual mask for varlen or when K seqlen is not multiple of
      // blockN
-      return dispatch_type(
+      return dispatch_head_dim(
          varlen, deterministic, ResidualMaskForBackward{});
    }
    else {
-      return dispatch_type(
+      return dispatch_head_dim(
          varlen, deterministic, NoMask{});
    }
  };

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu

Lines changed: 3 additions & 0 deletions
@@ -320,6 +320,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
     if (q.size(q.dim() - 1) == 128) {
       return dispatch_type(varlen, mask, std::integral_constant<int, 128>{});
     }
+    else if (q.size(q.dim() - 1) == 64) {
+      return dispatch_type(varlen, mask, std::integral_constant<int, 64>{});
+    }
     else {
       TORCH_CHECK(false, "Unsupported head dim: ", q.size(q.dim() - 1));
     }
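Both the forward and backward dispatchers rely on the same pattern: the runtime head dimension read from the query tensor is converted into a compile-time constant via std::integral_constant, so the kernel template (and its TileShape) can be specialized per head_dim. Below is a minimal, self-contained sketch of that dispatch pattern; run_kernel and dispatch_head_dim_example are illustrative names, not FBGEMM code.

// Sketch of compile-time head_dim dispatch via std::integral_constant.
// Illustrative only; uses nothing beyond the C++ standard library.
#include <cstdio>
#include <stdexcept>
#include <type_traits>

template <int HeadDim>
void run_kernel() {
  // The real code would instantiate the CUTLASS FMHA kernel here, e.g. with
  // TileShape = Shape<_128, _128, cute::Int<HeadDim>>.
  std::printf("kernel specialized for head_dim = %d\n", HeadDim);
}

void dispatch_head_dim_example(int head_dim) {
  // The generic lambda receives the head dim as a type-encoded constant,
  // so its ::value can be used as a template argument.
  auto dispatch = [](auto head_dim_c) {
    run_kernel<decltype(head_dim_c)::value>();
  };
  if (head_dim == 128) {
    dispatch(std::integral_constant<int, 128>{});
  } else if (head_dim == 64) {
    dispatch(std::integral_constant<int, 64>{});
  } else {
    throw std::runtime_error("Unsupported head dim");
  }
}

int main() {
  dispatch_head_dim_example(64);   // selects the HeadDim = 64 instantiation
  dispatch_head_dim_example(128);  // selects the HeadDim = 128 instantiation
  return 0;
}

Each supported head_dim therefore adds one runtime branch and one more kernel instantiation, which is why the diff touches both dispatchers and also adds TORCH_CHECK(D == HeadDim) in the backward path to catch mismatches between the tensor shape and the selected specialization.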

fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py

Lines changed: 26 additions & 8 deletions
@@ -440,17 +440,26 @@ def _execute_cutlass_blackwell_attn_varlen(
             seqlen_k,
             batch_size,
             is_mqa,
+            window_size,
+            head_dim,
+            sm_scale,
         )
         for seqlen_k in [64, 128, 256, 1024]
         for batch_size in [1, 2]
         for is_mqa in [True]
+        for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
+        for head_dim in [128]
+        for sm_scale in [None, 1.0 / head_dim]
         ]
     )
     def test_decode(
         self,
         seqlen_k: int,
         batch_size: int,
         is_mqa: bool,
+        window_size: tuple[int, int],
+        head_dim: int,
+        sm_scale: Optional[float],
         q_heads: int = 8,
         dtype: torch.dtype = torch.float8_e4m3fn,
     ) -> None:
@@ -465,7 +474,7 @@ def test_decode(
             seqlen_k,
             q_heads,
             kv_heads=1 if is_mqa else q_heads,
-            head_dim=128,
+            head_dim=head_dim,
             dtype=dtype,
             causal=causal,
             # Decode kernel does not support sliding window attention yet
@@ -486,14 +495,16 @@ def test_decode(
                 q_heads,
                 causal,
                 window_size,
+                head_dim,
                 sm_scale,
             )
             for kv_padding in [128, 256, 512, 1024]
             for batch_size in [2, 8]
             for q_heads in [8, 16]
             for causal in [True, False]
             for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
-            for sm_scale in [None, 1.0 / 128]
+            for head_dim in [128]
+            for sm_scale in [None, 1.0 / head_dim]
         ]
     )
     def test_jagged_vs_padded_kv(
@@ -503,6 +514,7 @@ def test_jagged_vs_padded_kv(
         q_heads: int,
         causal: bool,
         window_size: tuple[int, int],
+        head_dim: int,
         sm_scale: Optional[float],
     ) -> None:
         """
@@ -517,7 +529,7 @@ def test_jagged_vs_padded_kv(
         seqlen_q = kv_padding  # Maximum sequence length (padded size)
         device = torch.accelerator.current_accelerator()
         kv_heads = 1
-        head_dim = 128
+        head_dim = head_dim
         dtype = torch.bfloat16
 
         torch.manual_seed(SEED)
@@ -663,6 +675,7 @@ def test_jagged_vs_padded_kv(
                 is_varlen,
                 kv_heads,
                 window_size,
+                head_dim,
                 sm_scale,
             )
             for seqlen_q, offset_q in [
@@ -682,7 +695,8 @@ def test_jagged_vs_padded_kv(
             for is_varlen in [False, True]
             for kv_heads in [1, 2, 3, 4]
             for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
-            for sm_scale in [None, 1.0 / 128]
+            for head_dim in [64, 128]
+            for sm_scale in [None, 1.0 / head_dim]
         ]
     )
     def test_forward(
@@ -695,6 +709,7 @@ def test_forward(
         is_varlen: bool,
         kv_heads: int,
         window_size: tuple[int, int],
+        head_dim: int,
         sm_scale: Optional[float],
         dtype: torch.dtype = torch.bfloat16,
     ) -> None:
@@ -713,7 +728,7 @@ def test_forward(
             seqlen_k,
             q_heads=kv_heads * q_heads_per_kv_head,
             kv_heads=kv_heads,
-            head_dim=128,
+            head_dim=head_dim,
             dtype=dtype,
             causal=causal,
             window_size=window_size,
@@ -736,7 +751,8 @@ def test_forward(
             [(-1, -1), (128, 0), (256, 0), (128, 128), (512, 0)]
         ),
         deterministic=st.booleans(),
-        sm_scale=st.sampled_from([None, 1.0 / 128]),
+        head_dim=st.sampled_from([64, 128]),
+        is_sm_scale=st.booleans(),
     )
     @settings(**common_settings)
     def test_backward(
@@ -750,8 +766,10 @@ def test_backward(
         is_gqa: bool,
         window_size: tuple[int, int],
         deterministic: bool,
-        sm_scale: Optional[float],
+        head_dim: int,
+        is_sm_scale: bool,
     ) -> None:
+        sm_scale = 1.0 / head_dim if is_sm_scale else None
         test_func = (
             self._execute_cutlass_blackwell_attn_varlen
             if is_varlen
@@ -764,7 +782,7 @@ def test_backward(
             seqlen,
             q_heads=kv_heads * q_heads_per_kv_head,
             kv_heads=kv_heads,
-            head_dim=128,
+            head_dim=head_dim,
             dtype=dtype,
             causal=causal,
             window_size=window_size,
