
Commit 58928a3

henrylhtsang authored and facebook-github-bot committed
Small changes to improve blackwell_fmha_test.py (pytorch#4896)
Summary:
X-link: facebookresearch/FBGEMM#1922

Add some BE features:
* fix the seed
* increase backward test max_examples to 200
* decrease backward test verbosity
* improve the error message when an assertion fails

Reviewed By: q10

Differential Revision: D81992869
1 parent 53f9e51 commit 58928a3

File tree (1 file changed: +36, -13 lines)

fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py

@@ -21,13 +21,14 @@
 from .test_utils import attention_ref, generate_qkv, generate_random_padding_mask

 common_settings = {
-    "verbosity": Verbosity.verbose,
-    "max_examples": 20,
+    "verbosity": Verbosity.normal,
+    "max_examples": 200,
     "deadline": None,
     "suppress_health_check": [HealthCheck.filter_too_much, HealthCheck.data_too_large],
 }

 DEBUG = False
+SEED = 2

 compute_capability = (0, 0)
 if torch.cuda.is_available():
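
Context for the settings change above: common_settings is a plain dict of Hypothesis settings, which tests typically consume by unpacking it into the @settings decorator. A minimal sketch of that pattern follows; the test function and strategy are hypothetical and not part of this diff.

# Minimal sketch of how a dict like common_settings is usually applied with
# Hypothesis (run under pytest); test_example and its strategy are illustrative.
import torch
from hypothesis import HealthCheck, Verbosity, given, settings
from hypothesis import strategies as st

common_settings = {
    "verbosity": Verbosity.normal,  # quieter per-example logging
    "max_examples": 200,            # 10x more generated cases per test
    "deadline": None,               # no per-example time limit
    "suppress_health_check": [HealthCheck.filter_too_much, HealthCheck.data_too_large],
}

@given(batch_size=st.integers(min_value=1, max_value=4))
@settings(**common_settings)
def test_example(batch_size: int) -> None:
    x = torch.randn(batch_size, 8)
    assert x.shape == (batch_size, 8)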
@@ -50,21 +51,39 @@ def _allclose(
         t_pt: torch.Tensor,
     ) -> None:
         assert t_test.shape == t_ref.shape == t_pt.shape
+
+        ratio = 2.0
+
+        # Calculate all differences
+        test_ref_diff = self._abs_max(t_test - t_ref)
+        test_pt_diff = self._abs_max(t_test - t_pt)
+        pt_ref_diff = self._abs_max(t_pt - t_ref)
+
         if DEBUG:
             # Debug: Print the differences
+            print(f"DEBUG: Max absolute difference vs ref: {test_ref_diff}")
+            print(f"DEBUG: Max absolute difference vs pt: {test_pt_diff}")
+            print(f"DEBUG: Max absolute difference pt vs ref: {pt_ref_diff}")
             print(
-                f"DEBUG: Max absolute difference vs ref: {self._abs_max(t_test - t_ref)}"
-            )
-            print(
-                f"DEBUG: Max absolute difference vs pt: {self._abs_max(t_test - t_pt)}"
-            )
-            print(
-                f"DEBUG: Max absolute difference pt vs ref: {self._abs_max(t_pt - t_ref)}"
+                f"DEBUG: Tolerance check: {test_ref_diff} <= {ratio * pt_ref_diff + 1e-5}"
             )
-            print(
-                f"DEBUG: Tolerance check: {self._abs_max(t_test - t_ref)} <= {2 * self._abs_max(t_pt - t_ref) + 1e-5}"
-            )
-        assert self._abs_max(t_test - t_ref) <= 2 * self._abs_max(t_pt - t_ref) + 1e-4
+
+        # First assertion with gap information
+        tolerance_threshold = ratio * pt_ref_diff + 1e-4
+        assert test_ref_diff <= tolerance_threshold, (
+            f"Tolerance check failed: max_diff={test_ref_diff:.6f} > "
+            f"threshold={tolerance_threshold:.6f}, gap={test_ref_diff - tolerance_threshold:.6f}"
+        )
+
+        # sanity checks
+        assert test_ref_diff <= 0.5, (
+            f"Max difference vs ref too large: {test_ref_diff:.6f} > 0.5, "
+            f"gap={test_ref_diff - 0.5:.6f}"
+        )
+        assert pt_ref_diff <= 0.5, (
+            f"Max difference pt vs ref too large: {pt_ref_diff:.6f} > 0.5, "
+            f"gap={pt_ref_diff - 0.5:.6f}"
+        )

     def _generate_qkv(
         self,
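
The rewritten _allclose keeps the original acceptance rule: the kernel under test may deviate from the reference by at most ratio (2.0) times the PyTorch baseline's own deviation from the reference, plus a small epsilon; the assertion messages now report the diff, threshold, and gap. A standalone sketch of that check, with a free function and illustrative names standing in for the class helpers:

# Standalone sketch of the relative-tolerance check used above; abs_max,
# check_close, and the tensors below are illustrative stand-ins.
import torch

def abs_max(t: torch.Tensor) -> float:
    # Largest absolute entry, as a Python float.
    return t.abs().max().item()

def check_close(
    t_test: torch.Tensor, t_ref: torch.Tensor, t_pt: torch.Tensor, ratio: float = 2.0
) -> None:
    test_ref_diff = abs_max(t_test - t_ref)  # kernel under test vs. reference
    pt_ref_diff = abs_max(t_pt - t_ref)      # PyTorch baseline vs. reference
    threshold = ratio * pt_ref_diff + 1e-4   # allow 2x the baseline's own error
    assert test_ref_diff <= threshold, (
        f"max_diff={test_ref_diff:.6f} > threshold={threshold:.6f}, "
        f"gap={test_ref_diff - threshold:.6f}"
    )

# Example: a baseline that is itself slightly off still bounds the test output.
ref = torch.zeros(4)
check_close(t_test=ref + 0.01, t_ref=ref, t_pt=ref + 0.02)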
@@ -120,6 +139,7 @@ def _execute_cutlass_blackwell_attn_dense(
     ) -> None:
         device = torch.accelerator.current_accelerator()
         assert device is not None
+        torch.manual_seed(SEED)
         assert seqlen_q <= seqlen_k
         q, k, v = self._generate_qkv(
             batch_size,

@@ -247,6 +267,7 @@ def _execute_cutlass_blackwell_attn_varlen(
     ) -> None:
         device = torch.accelerator.current_accelerator()
         assert device is not None
+        torch.manual_seed(SEED)
         q_ref, k_ref, v_ref = self._generate_qkv(
             batch_size,
             seqlen_q,

@@ -472,6 +493,8 @@ def test_jagged_vs_padded_kv(
         head_dim = 128
         dtype = torch.bfloat16

+        torch.manual_seed(SEED)
+
         # Create tensors
         q_padded = torch.randn(
             batch_size,
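
All three entry points above now call torch.manual_seed(SEED) before generating inputs, which is what "fix the seed" in the summary refers to: the randomly generated tensors become identical across runs, so a failing example can be replayed. A small self-contained sketch of the effect; the shapes and helper name here are illustrative, not from the test file.

# Sketch: with a fixed seed, the global RNG produces bit-identical tensors
# on every call, so a failure seen in one run reproduces in the next.
import torch

SEED = 2

def make_inputs() -> torch.Tensor:
    torch.manual_seed(SEED)          # reset the global RNG before sampling
    return torch.randn(2, 8, 4, 16)  # e.g. (batch, seqlen, heads, head_dim)

a = make_inputs()
b = make_inputs()
assert torch.equal(a, b)  # same seed -> identical inputs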
