21
21
from .test_utils import attention_ref , generate_qkv , generate_random_padding_mask
22
22
23
23
common_settings = {
24
- "verbosity" : Verbosity .verbose ,
25
- "max_examples" : 20 ,
24
+ "verbosity" : Verbosity .normal ,
25
+ "max_examples" : 200 ,
26
26
"deadline" : None ,
27
27
"suppress_health_check" : [HealthCheck .filter_too_much , HealthCheck .data_too_large ],
28
28
}
29
29
30
30
DEBUG = False
31
+ SEED = 2
31
32
32
33
compute_capability = (0 , 0 )
33
34
if torch .cuda .is_available ():
@@ -50,21 +51,39 @@ def _allclose(
50
51
t_pt : torch .Tensor ,
51
52
) -> None :
52
53
assert t_test .shape == t_ref .shape == t_pt .shape
54
+
55
+ ratio = 2.0
56
+
57
+ # Calculate all differences
58
+ test_ref_diff = self ._abs_max (t_test - t_ref )
59
+ test_pt_diff = self ._abs_max (t_test - t_pt )
60
+ pt_ref_diff = self ._abs_max (t_pt - t_ref )
61
+
53
62
if DEBUG :
54
63
# Debug: Print the differences
64
+ print (f"DEBUG: Max absolute difference vs ref: { test_ref_diff } " )
65
+ print (f"DEBUG: Max absolute difference vs pt: { test_pt_diff } " )
66
+ print (f"DEBUG: Max absolute difference pt vs ref: { pt_ref_diff } " )
55
67
print (
56
- f"DEBUG: Max absolute difference vs ref: { self ._abs_max (t_test - t_ref )} "
57
- )
58
- print (
59
- f"DEBUG: Max absolute difference vs pt: { self ._abs_max (t_test - t_pt )} "
60
- )
61
- print (
62
- f"DEBUG: Max absolute difference pt vs ref: { self ._abs_max (t_pt - t_ref )} "
68
+ f"DEBUG: Tolerance check: { test_ref_diff } <= { ratio * pt_ref_diff + 1e-5 } "
63
69
)
64
- print (
65
- f"DEBUG: Tolerance check: { self ._abs_max (t_test - t_ref )} <= { 2 * self ._abs_max (t_pt - t_ref ) + 1e-5 } "
66
- )
67
- assert self ._abs_max (t_test - t_ref ) <= 2 * self ._abs_max (t_pt - t_ref ) + 1e-4
70
+
71
+ # First assertion with gap information
72
+ tolerance_threshold = ratio * pt_ref_diff + 1e-4
73
+ assert test_ref_diff <= tolerance_threshold , (
74
+ f"Tolerance check failed: max_diff={ test_ref_diff :.6f} > "
75
+ f"threshold={ tolerance_threshold :.6f} , gap={ test_ref_diff - tolerance_threshold :.6f} "
76
+ )
77
+
78
+ # sanity checks
79
+ assert test_ref_diff <= 0.5 , (
80
+ f"Max difference vs ref too large: { test_ref_diff :.6f} > 0.5, "
81
+ f"gap={ test_ref_diff - 0.5 :.6f} "
82
+ )
83
+ assert pt_ref_diff <= 0.5 , (
84
+ f"Max difference pt vs ref too large: { pt_ref_diff :.6f} > 0.5, "
85
+ f"gap={ pt_ref_diff - 0.5 :.6f} "
86
+ )
68
87
69
88
def _generate_qkv (
70
89
self ,
@@ -120,6 +139,7 @@ def _execute_cutlass_blackwell_attn_dense(
120
139
) -> None :
121
140
device = torch .accelerator .current_accelerator ()
122
141
assert device is not None
142
+ torch .manual_seed (SEED )
123
143
assert seqlen_q <= seqlen_k
124
144
q , k , v = self ._generate_qkv (
125
145
batch_size ,
@@ -247,6 +267,7 @@ def _execute_cutlass_blackwell_attn_varlen(
247
267
) -> None :
248
268
device = torch .accelerator .current_accelerator ()
249
269
assert device is not None
270
+ torch .manual_seed (SEED )
250
271
q_ref , k_ref , v_ref = self ._generate_qkv (
251
272
batch_size ,
252
273
seqlen_q ,
@@ -472,6 +493,8 @@ def test_jagged_vs_padded_kv(
472
493
head_dim = 128
473
494
dtype = torch .bfloat16
474
495
496
+ torch .manual_seed (SEED )
497
+
475
498
# Create tensors
476
499
q_padded = torch .randn (
477
500
batch_size ,
0 commit comments