
Commit a14de4d

Remove unsupported parameters from decode kernel tests (#4911)
Summary: X-link: facebookresearch/FBGEMM#1936

D80992628 introduced SWA FWD kernel changes that did not support decode kernels (i.e., supporting sm100_fmha_fwd but not sm100_fmha_gen). Similarly, softmax_scale, introduced in D82788784, did not support decode kernels either. In blackwell_fmha_test, these parameters are dropped during decode kernel selection (https://www.internalfb.com/code/fbsource/[cd7066706035]/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py?lines=182). To avoid confusion, do not run test_decode with parameters it ignores.

Reviewed By: Aya-ZIbra, sryap

Differential Revision: D82991496
1 parent 03f6bde commit a14de4d

File tree

1 file changed: +4 additions, -8 deletions


fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py

Lines changed: 4 additions & 8 deletions
@@ -439,23 +439,17 @@ def _execute_cutlass_blackwell_attn_varlen(
                 seqlen_k,
                 batch_size,
                 is_mqa,
-                window_size,
-                sm_scale,
             )
             for seqlen_k in [64, 128, 256, 1024]
             for batch_size in [1, 2]
             for is_mqa in [True]
-            for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
-            for sm_scale in [None, 1.0 / 128]
         ]
     )
     def test_decode(
         self,
         seqlen_k: int,
         batch_size: int,
         is_mqa: bool,
-        window_size: tuple[int, int],
-        sm_scale: Optional[float],
         q_heads: int = 8,
         dtype: torch.dtype = torch.float8_e4m3fn,
     ) -> None:
@@ -473,10 +467,12 @@ def test_decode(
             head_dim=128,
             dtype=dtype,
             causal=causal,
-            window_size=window_size,
+            # Decode kernel does not support sliding window attention yet
+            window_size=(-1, -1),
             fwd_only=True,
             deterministic=False,
-            sm_scale=sm_scale,
+            # Decode kernel does not support sm_scale
+            sm_scale=None,
         )
 
     @skip_cuda_lt_sm100
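For context, the decorator in the first hunk builds its test cases as a Cartesian product of nested comprehensions, so dropping the two ignored loops shrinks the decode sweep from 80 cases to 8. A minimal standalone sketch of that pattern (plain Python with values copied from the diff; not the actual test harness):

```python
# Before: five nested loops multiply into the full parameter grid.
cases_before = [
    (seqlen_k, batch_size, is_mqa, window_size, sm_scale)
    for seqlen_k in [64, 128, 256, 1024]
    for batch_size in [1, 2]
    for is_mqa in [True]
    for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
    for sm_scale in [None, 1.0 / 128]
]

# After: the decode kernel ignores window_size and sm_scale, so the
# sweep collapses to the three supported axes.
cases_after = [
    (seqlen_k, batch_size, is_mqa)
    for seqlen_k in [64, 128, 256, 1024]
    for batch_size in [1, 2]
    for is_mqa in [True]
]

print(len(cases_before), len(cases_after))  # 80 8
```

Every dropped case was a duplicate as far as the decode kernel is concerned, since window_size and sm_scale were being overwritten during decode kernel selection anyway.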
