
Commit 58a7a7e

henrylhtsang authored and facebook-github-bot committed
Add softmax_scale to blackwell cutlass fmha
Summary: Adding softmax_scale plumbing to unblock D82490887 for tritonbench.

Differential Revision: D82788784
1 parent 53f9e51 commit 58a7a7e

File tree

5 files changed: +57 −9 lines changed

fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py

Lines changed: 3 additions & 0 deletions

@@ -100,6 +100,7 @@ def _cutlass_blackwell_fmha_backward(
     cu_seqlens_k: torch.Tensor | None = None,
     max_seq_len_q: int | None = None,
     max_seq_len_k: int | None = None,
+    softmax_scale: float | None = None,
     causal: bool = False,
     window_left: int = -1,
     window_right: int = -1,
@@ -123,6 +124,7 @@ def _cutlass_blackwell_fmha_backward(
         cu_seqlens_k=cu_seqlens_k,
         max_seq_len_q=max_seq_len_q,
         max_seq_len_k=max_seq_len_k,
+        softmax_scale=softmax_scale,
         causal=causal,
         window_size_left=window_left,
         window_size_right=window_right,
@@ -274,6 +276,7 @@ def backward(ctx, dout: torch.Tensor, *args: Any) -> Tuple[ # type: ignore
             ctx.cu_seqlens_k,
             ctx.max_seq_len_q,
             ctx.max_seq_len_k,
+            ctx.softmax_scale,
             ctx.causal,
             window_left,
             window_right,
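
For orientation, a minimal sketch of the new argument from the caller's side, using the cutlass_blackwell_fmha_func wrapper that the tests below exercise. The import path, tensor shapes, and dtypes here are illustrative assumptions, not taken from this diff:

import math

import torch

# Import path assumed for illustration; see cutlass_blackwell_fmha_interface.py above.
from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
    cutlass_blackwell_fmha_func,
)

B, S, H, D = 2, 128, 8, 128
q, k, v = (
    torch.randn(B, S, H, D, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    for _ in range(3)
)

# softmax_scale=None preserves the existing default scaling; an explicit float
# overrides it, and the same value is plumbed through to the backward pass.
out_default = cutlass_blackwell_fmha_func(q, k, v, causal=True, softmax_scale=None)
out_scaled = cutlass_blackwell_fmha_func(
    q, k, v, causal=True, softmax_scale=1.0 / math.sqrt(D)
)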

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_bwd.cu

Lines changed: 6 additions & 3 deletions

@@ -19,6 +19,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
     const std::optional<const at::Tensor>& cu_seqlens_k,
     std::optional<int> max_seq_len_q,
     std::optional<int> max_seq_len_k,
+    const std::optional<double> softmax_scale,
     const int window_size_left,
     const int window_size_right
 ) {
@@ -207,8 +208,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
     get<2, 1>(stride_dV) = 0;
   }
 
-  // TODO: pass in softmax_scale?
-  ElementAccumulator softmax_scale = 1.0f / sqrtf(D);
+  ElementAccumulator softmax_scale_value = softmax_scale.has_value() ? softmax_scale.value() : (1.0f / sqrtf(D));
 
   at::Tensor dQ = torch::empty_like(q);
   at::Tensor dK = torch::empty_like(k);
@@ -253,7 +253,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
       stride_dK,
       static_cast<Element*>(dV.data_ptr()),
       stride_dV,
-      softmax_scale,
+      softmax_scale_value,
       dq_semaphore_ptr,
       window_size_left,
       window_size_right,
@@ -276,6 +276,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
     const std::optional<at::Tensor>& cu_seqlens_k,
     std::optional<int64_t> max_seq_len_q,
     std::optional<int64_t> max_seq_len_k,
+    std::optional<double> softmax_scale,
     bool causal,
     int64_t window_size_left,
     int64_t window_size_right,
@@ -334,6 +335,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
         cu_seqlens_k,
         max_seq_len_q,
         max_seq_len_k,
+        softmax_scale,
         window_size_left,
         window_size_right);
   };
@@ -419,6 +421,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       " Tensor? cu_seqlens_k=None, "
       " int? max_seq_len_q=None, "
       " int? max_seq_len_k=None, "
+      " float? softmax_scale=None, "
       " bool causal=False, "
       " int window_size_left=-1, "
       " int window_size_right=-1, "

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu

Lines changed: 1 addition & 1 deletion

@@ -225,7 +225,7 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
         static_cast<Element*>(v.data_ptr()), stride_V,
         window_size_left, window_size_right
       },
-      0.0f /* softmax_scale */,
+      static_cast<float>(softmax_scale.value_or(0.0f)) /* softmax_scale */,
       1.0f /* scale_q */,
       1.0f /* scale_k */,
       1.0f /* scale_v */,
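
The forward path keeps the pre-existing 0.0f fallback and only substitutes the caller's value when one is present. In Python terms, the value_or call behaves roughly like this sketch:

def forward_softmax_scale_arg(softmax_scale):
    # Approximation of std::optional::value_or(0.0f): the caller's value if present,
    # otherwise the same 0.0 placeholder the forward kernel received before this change.
    return float(softmax_scale) if softmax_scale is not None else 0.0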

fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py

Lines changed: 41 additions & 3 deletions

@@ -6,7 +6,7 @@
 
 import random
 import unittest
-from typing import Tuple
+from typing import Optional, Tuple
 
 import hypothesis.strategies as st
 import torch
@@ -117,10 +117,15 @@ def _execute_cutlass_blackwell_attn_dense(
         window_size: tuple[int, int],
         fwd_only: bool,
         deterministic: bool,
+        sm_scale: Optional[float],
     ) -> None:
         device = torch.accelerator.current_accelerator()
         assert device is not None
         assert seqlen_q <= seqlen_k
+
+        # Initialize deterministic variables
+        out_d = None
+
         q, k, v = self._generate_qkv(
             batch_size,
             seqlen_q,
@@ -144,7 +149,13 @@ def _execute_cutlass_blackwell_attn_dense(
 
         # Run reference attention
         out_baseline, _ = attention_ref(
-            q, k, v, causal=causal, window_size=window_size, upcast=True
+            q,
+            k,
+            v,
+            causal=causal,
+            window_size=window_size,
+            upcast=True,
+            softmax_scale=sm_scale,
         )
         if dtype == torch.float8_e4m3fn:
             # reference implementation only supports decode case (seqlen_q == 1)
@@ -161,6 +172,7 @@ def _execute_cutlass_blackwell_attn_dense(
             window_size=window_size,
             reorder_ops=True,
             upcast=False,
+            softmax_scale=sm_scale,
         )
 
         # Run tested kernel
@@ -172,6 +184,7 @@ def _execute_cutlass_blackwell_attn_dense(
             window_size=window_size,
             seqlen_kv=seqlen_kv,
             deterministic=deterministic,
+            softmax_scale=sm_scale,
         )
         if DEBUG:
             print("cutlass_blackwell_fmha_func completed successfully!")
@@ -190,6 +203,7 @@ def _execute_cutlass_blackwell_attn_dense(
                 window_size=window_size,
                 seqlen_kv=seqlen_kv,
                 deterministic=deterministic,
+                softmax_scale=sm_scale,
             )
             assert torch.equal(out, out_d)
 
@@ -244,9 +258,13 @@ def _execute_cutlass_blackwell_attn_varlen(
         window_size: tuple[int, int],
         fwd_only: bool,
         deterministic: bool,
+        sm_scale: Optional[float],
     ) -> None:
         device = torch.accelerator.current_accelerator()
         assert device is not None
+
+        # Initialize deterministic variables
+        out_unpad_d = None
         q_ref, k_ref, v_ref = self._generate_qkv(
             batch_size,
             seqlen_q,
@@ -306,6 +324,7 @@ def _execute_cutlass_blackwell_attn_varlen(
             key_padding_mask,
             causal=causal,
             window_size=window_size,
+            softmax_scale=sm_scale,
         )
 
         out_pt, _ = attention_ref(
@@ -318,6 +337,7 @@ def _execute_cutlass_blackwell_attn_varlen(
             window_size=window_size,
             upcast=False,
             reorder_ops=True,
+            softmax_scale=sm_scale,
         )
 
         out_unpad = cutlass_blackwell_fmha_func(
@@ -331,6 +351,7 @@ def _execute_cutlass_blackwell_attn_varlen(
             max_seq_len_k=max_seqlen_k,
             window_size=window_size,
             deterministic=deterministic,
+            softmax_scale=sm_scale,
         )
         out = output_pad_fn(out_unpad)
 
@@ -351,6 +372,7 @@ def _execute_cutlass_blackwell_attn_varlen(
                 max_seq_len_k=max_seqlen_k,
                 window_size=window_size,
                 deterministic=deterministic,
+                softmax_scale=sm_scale,
             )
             out_d = output_pad_fn(out_unpad_d)
             assert torch.equal(out, out_d)
@@ -396,11 +418,13 @@ def _execute_cutlass_blackwell_attn_varlen(
                 batch_size,
                 is_mqa,
                 window_size,
+                sm_scale,
             )
             for seqlen_k in [64, 128, 256, 1024]
             for batch_size in [1, 2]
             for is_mqa in [True]
             for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
+            for sm_scale in [None, 1.0 / 128]
         ]
     )
     def test_decode(
@@ -409,6 +433,7 @@ def test_decode(
         batch_size: int,
         is_mqa: bool,
         window_size: tuple[int, int],
+        sm_scale: Optional[float],
         q_heads: int = 8,
         dtype: torch.dtype = torch.float8_e4m3fn,
     ) -> None:
@@ -429,6 +454,7 @@ def test_decode(
             window_size=window_size,
             fwd_only=True,
             deterministic=False,
+            sm_scale=sm_scale,
         )
 
     @skip_cuda_lt_sm100
@@ -441,12 +467,14 @@ def test_decode(
                 q_heads,
                 causal,
                 window_size,
+                sm_scale,
             )
             for kv_padding in [128, 256, 512, 1024]
             for batch_size in [2, 8]
             for q_heads in [8, 16]
             for causal in [True, False]
             for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
+            for sm_scale in [None, 1.0 / 128]
         ]
     )
     def test_jagged_vs_padded_kv(
@@ -455,7 +483,8 @@ def test_jagged_vs_padded_kv(
         batch_size: int,
         q_heads: int,
         causal: bool,
-        window_size: tuple[int, int] = (-1, -1),
+        window_size: tuple[int, int],
+        sm_scale: Optional[float],
     ) -> None:
         """
         Test comparing two scenarios:
@@ -565,6 +594,7 @@ def test_jagged_vs_padded_kv(
             max_seq_len_k=max_seqlen_k,
             causal=causal,
             window_size=window_size,
+            softmax_scale=sm_scale,
         )
 
         # # Scenario B: Padded KV with seqlen_kv
@@ -583,6 +613,7 @@ def test_jagged_vs_padded_kv(
             causal=causal,
             window_size=window_size,
             seqlen_kv=seqused_k,
+            softmax_scale=sm_scale,
        )
         if DEBUG:
             print(f"out_jagged: {out_jagged}")
@@ -611,6 +642,7 @@ def test_jagged_vs_padded_kv(
                 is_varlen,
                 kv_heads,
                 window_size,
+                sm_scale,
             )
             for seqlen_q, offset_q in [
                 (101, 0),
@@ -629,6 +661,7 @@ def test_jagged_vs_padded_kv(
             for is_varlen in [False, True]
             for kv_heads in [1, 2, 3, 4]
             for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
+            for sm_scale in [None, 1.0 / 128]
         ]
     )
     def test_forward(
@@ -641,6 +674,7 @@ def test_forward(
         is_varlen: bool,
         kv_heads: int,
         window_size: tuple[int, int],
+        sm_scale: Optional[float],
         dtype: torch.dtype = torch.bfloat16,
     ) -> None:
         seqlen_k = offset_q + seqlen_q
@@ -664,6 +698,7 @@ def test_forward(
             window_size=window_size,
             fwd_only=True,
             deterministic=False,
+            sm_scale=sm_scale,
         )
 
     @skip_cuda_lt_sm100
@@ -680,6 +715,7 @@ def test_forward(
             [(-1, -1), (128, 0), (256, 0), (128, 128), (512, 0)]
         ),
         deterministic=st.booleans(),
+        sm_scale=st.sampled_from([None, 1.0 / 128]),
     )
     @settings(**common_settings)
     def test_backward(
@@ -693,6 +729,7 @@ def test_backward(
         is_gqa: bool,
         window_size: tuple[int, int],
         deterministic: bool,
+        sm_scale: Optional[float],
     ) -> None:
         test_func = (
             self._execute_cutlass_blackwell_attn_varlen
@@ -712,4 +749,5 @@ def test_backward(
             window_size=window_size,
             fwd_only=False,
             deterministic=deterministic,
+            sm_scale=sm_scale,
         )
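
The new sm_scale sweeps cover both the default (None) and an explicit scale (1.0 / 128). A condensed sketch of the comparison the dense test path performs, assuming the same cutlass_blackwell_fmha_func and attention_ref helpers used by the test module (import paths assumed) and an SM100-class GPU; tolerances are illustrative:

import torch

from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (  # path assumed
    cutlass_blackwell_fmha_func,
)
from test_utils import attention_ref  # as in the test module; path assumed


def check_against_reference(q, k, v, sm_scale):
    out_kernel = cutlass_blackwell_fmha_func(q, k, v, causal=False, softmax_scale=sm_scale)
    out_ref, _ = attention_ref(
        q.float(), k.float(), v.float(), upcast=True, softmax_scale=sm_scale
    )
    # The kernel output should track the fp32 reference closely for either scale.
    torch.testing.assert_close(out_kernel.float(), out_ref, atol=2e-2, rtol=2e-2)


for sm_scale in [None, 1.0 / 128]:
    q, k, v = (
        torch.randn(2, 128, 8, 128, dtype=torch.bfloat16, device="cuda")
        for _ in range(3)
    )
    check_against_reference(q, k, v, sm_scale)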

fbgemm_gpu/experimental/gen_ai/test/attention/test_utils.py

Lines changed: 6 additions & 2 deletions

@@ -243,6 +243,7 @@ def attention_ref( # noqa
     upcast=True,
     reorder_ops=False,
     key_leftpad=None,
+    softmax_scale=None,
 ):
     """
     Arguments:
@@ -261,6 +262,7 @@ def attention_ref( # noqa
         reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
             without changing the math. This is to estimate the numerical error from operation
             reordering.
+        softmax_scale: float, scale for softmax. If None, use 1/sqrt(head_dim)
     Output:
         output: (batch_size, seqlen_q, nheads, head_dim)
         attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
@@ -274,10 +276,12 @@ def attention_ref( # noqa
         k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
         v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
     d = q.shape[-1]
+    # Use provided softmax_scale or default to 1/sqrt(d)
+    scale = softmax_scale if softmax_scale is not None else 1 / math.sqrt(d)
     if not reorder_ops:
-        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
+        scores = torch.einsum("bthd,bshd->bhts", q * scale, k)
     else:
-        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+        scores = torch.einsum("bthd,bshd->bhts", q, k * scale)
     if softcap > 0:
         scores /= softcap
         scores = scores.tanh()
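
Net effect on the reference: the score matrix is now scale * Q @ K^T, with scale defaulting to 1/sqrt(d) when softmax_scale is None, so existing callers see no change. A self-contained sanity check of that equivalence in plain PyTorch (no masking, dropout, or softcap), using the same bthd layout as attention_ref:

import math

import torch


def simple_attention(q, k, v, softmax_scale=None):
    # q, k, v: (batch, seqlen, nheads, head_dim), matching attention_ref's layout.
    d = q.shape[-1]
    scale = softmax_scale if softmax_scale is not None else 1 / math.sqrt(d)
    scores = torch.einsum("bthd,bshd->bhts", q * scale, k)
    attn = scores.softmax(dim=-1)
    return torch.einsum("bhts,bshd->bthd", attn, v)


q, k, v = (torch.randn(1, 16, 2, 64) for _ in range(3))
out_default = simple_attention(q, k, v)  # implicit scale = 1/sqrt(64) = 0.125
out_explicit = simple_attention(q, k, v, softmax_scale=0.125)
torch.testing.assert_close(out_default, out_explicit)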
