@@ -6,7 +6,7 @@

 import random
 import unittest
-from typing import Tuple
+from typing import Optional, Tuple

 import hypothesis.strategies as st
 import torch
@@ -117,10 +117,15 @@ def _execute_cutlass_blackwell_attn_dense(
         window_size: tuple[int, int],
         fwd_only: bool,
         deterministic: bool,
+        sm_scale: Optional[float],
     ) -> None:
         device = torch.accelerator.current_accelerator()
         assert device is not None
         assert seqlen_q <= seqlen_k
+
+        # Initialize deterministic variables
+        out_d = None
+
         q, k, v = self._generate_qkv(
             batch_size,
             seqlen_q,
@@ -144,7 +149,13 @@ def _execute_cutlass_blackwell_attn_dense(

         # Run reference attention
         out_baseline, _ = attention_ref(
-            q, k, v, causal=causal, window_size=window_size, upcast=True
+            q,
+            k,
+            v,
+            causal=causal,
+            window_size=window_size,
+            upcast=True,
+            softmax_scale=sm_scale,
         )
         if dtype == torch.float8_e4m3fn:
             # reference implementation only supports decode case (seqlen_q == 1)
@@ -161,6 +172,7 @@ def _execute_cutlass_blackwell_attn_dense(
                 window_size=window_size,
                 reorder_ops=True,
                 upcast=False,
+                softmax_scale=sm_scale,
             )

         # Run tested kernel
@@ -172,6 +184,7 @@ def _execute_cutlass_blackwell_attn_dense(
             window_size=window_size,
             seqlen_kv=seqlen_kv,
             deterministic=deterministic,
+            softmax_scale=sm_scale,
         )
         if DEBUG:
             print("cutlass_blackwell_fmha_func completed successfully!")
@@ -190,6 +203,7 @@ def _execute_cutlass_blackwell_attn_dense(
                 window_size=window_size,
                 seqlen_kv=seqlen_kv,
                 deterministic=deterministic,
+                softmax_scale=sm_scale,
             )
             assert torch.equal(out, out_d)

@@ -244,9 +258,13 @@ def _execute_cutlass_blackwell_attn_varlen(
         window_size: tuple[int, int],
         fwd_only: bool,
         deterministic: bool,
+        sm_scale: Optional[float],
     ) -> None:
         device = torch.accelerator.current_accelerator()
         assert device is not None
+
+        # Initialize deterministic variables
+        out_unpad_d = None
         q_ref, k_ref, v_ref = self._generate_qkv(
             batch_size,
             seqlen_q,
@@ -306,6 +324,7 @@ def _execute_cutlass_blackwell_attn_varlen(
             key_padding_mask,
             causal=causal,
             window_size=window_size,
+            softmax_scale=sm_scale,
         )

         out_pt, _ = attention_ref(
@@ -318,6 +337,7 @@ def _execute_cutlass_blackwell_attn_varlen(
             window_size=window_size,
             upcast=False,
             reorder_ops=True,
+            softmax_scale=sm_scale,
         )

         out_unpad = cutlass_blackwell_fmha_func(
@@ -331,6 +351,7 @@ def _execute_cutlass_blackwell_attn_varlen(
             max_seq_len_k=max_seqlen_k,
             window_size=window_size,
             deterministic=deterministic,
+            softmax_scale=sm_scale,
         )
         out = output_pad_fn(out_unpad)

@@ -351,6 +372,7 @@ def _execute_cutlass_blackwell_attn_varlen(
                 max_seq_len_k=max_seqlen_k,
                 window_size=window_size,
                 deterministic=deterministic,
+                softmax_scale=sm_scale,
             )
             out_d = output_pad_fn(out_unpad_d)
             assert torch.equal(out, out_d)
@@ -396,11 +418,13 @@ def _execute_cutlass_blackwell_attn_varlen(
                 batch_size,
                 is_mqa,
                 window_size,
+                sm_scale,
             )
             for seqlen_k in [64, 128, 256, 1024]
             for batch_size in [1, 2]
             for is_mqa in [True]
             for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
+            for sm_scale in [None, 1.0 / 128]
         ]
     )
     def test_decode(
@@ -409,6 +433,7 @@ def test_decode(
         batch_size: int,
         is_mqa: bool,
         window_size: tuple[int, int],
+        sm_scale: Optional[float],
         q_heads: int = 8,
         dtype: torch.dtype = torch.float8_e4m3fn,
     ) -> None:
@@ -429,6 +454,7 @@ def test_decode(
             window_size=window_size,
             fwd_only=True,
             deterministic=False,
+            sm_scale=sm_scale,
         )

     @skip_cuda_lt_sm100
@@ -441,12 +467,14 @@ def test_decode(
                 q_heads,
                 causal,
                 window_size,
+                sm_scale,
             )
             for kv_padding in [128, 256, 512, 1024]
             for batch_size in [2, 8]
             for q_heads in [8, 16]
             for causal in [True, False]
             for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
+            for sm_scale in [None, 1.0 / 128]
         ]
     )
     def test_jagged_vs_padded_kv(
@@ -455,7 +483,8 @@ def test_jagged_vs_padded_kv(
         batch_size: int,
         q_heads: int,
         causal: bool,
-        window_size: tuple[int, int] = (-1, -1),
+        window_size: tuple[int, int],
+        sm_scale: Optional[float],
     ) -> None:
         """
         Test comparing two scenarios:
@@ -565,6 +594,7 @@ def test_jagged_vs_padded_kv(
             max_seq_len_k=max_seqlen_k,
             causal=causal,
             window_size=window_size,
+            softmax_scale=sm_scale,
         )

         # # Scenario B: Padded KV with seqlen_kv
@@ -583,6 +613,7 @@ def test_jagged_vs_padded_kv(
             causal=causal,
             window_size=window_size,
             seqlen_kv=seqused_k,
+            softmax_scale=sm_scale,
         )
         if DEBUG:
             print(f"out_jagged: {out_jagged}")
@@ -611,6 +642,7 @@ def test_jagged_vs_padded_kv(
                 is_varlen,
                 kv_heads,
                 window_size,
+                sm_scale,
             )
             for seqlen_q, offset_q in [
                 (101, 0),
@@ -629,6 +661,7 @@ def test_jagged_vs_padded_kv(
             for is_varlen in [False, True]
             for kv_heads in [1, 2, 3, 4]
             for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
+            for sm_scale in [None, 1.0 / 128]
         ]
     )
     def test_forward(
@@ -641,6 +674,7 @@ def test_forward(
         is_varlen: bool,
         kv_heads: int,
         window_size: tuple[int, int],
+        sm_scale: Optional[float],
         dtype: torch.dtype = torch.bfloat16,
     ) -> None:
         seqlen_k = offset_q + seqlen_q
@@ -664,6 +698,7 @@ def test_forward(
             window_size=window_size,
             fwd_only=True,
             deterministic=False,
+            sm_scale=sm_scale,
         )

     @skip_cuda_lt_sm100
@@ -680,6 +715,7 @@ def test_forward(
             [(-1, -1), (128, 0), (256, 0), (128, 128), (512, 0)]
         ),
         deterministic=st.booleans(),
+        sm_scale=st.sampled_from([None, 1.0 / 128]),
     )
     @settings(**common_settings)
     def test_backward(
@@ -693,6 +729,7 @@ def test_backward(
         is_gqa: bool,
         window_size: tuple[int, int],
         deterministic: bool,
+        sm_scale: Optional[float],
     ) -> None:
         test_func = (
             self._execute_cutlass_blackwell_attn_varlen
@@ -712,4 +749,5 @@ def test_backward(
             window_size=window_size,
             fwd_only=False,
             deterministic=deterministic,
+            sm_scale=sm_scale,
         )
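
The diff threads an optional softmax scale through every dense and varlen test path: the parameter grids add sm_scale in [None, 1.0 / 128], and the value is forwarded as softmax_scale to both attention_ref and cutlass_blackwell_fmha_func. The sketch below is illustrative only and is not part of the kernel or the test file; it assumes the FlashAttention-style convention that softmax_scale=None falls back to the default 1 / sqrt(head_dim), and the reference_attention helper is a hypothetical stand-in for attention_ref.

import math
from typing import Optional

import torch


def reference_attention(
    q: torch.Tensor,  # (batch, seqlen_q, heads, head_dim)
    k: torch.Tensor,  # (batch, seqlen_k, heads, head_dim)
    v: torch.Tensor,  # (batch, seqlen_k, heads, head_dim)
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    head_dim = q.shape[-1]
    # Assumption: None selects the conventional default scale 1/sqrt(head_dim).
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
    probs = scores.softmax(dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, v)


q = torch.randn(1, 4, 2, 128)
k = torch.randn(1, 8, 2, 128)
v = torch.randn(1, 8, 2, 128)
# With head_dim=128, the grid's explicit value 1.0 / 128 differs from the
# default 1/sqrt(128), so a kernel that silently ignored softmax_scale would
# diverge from the reference and fail the comparison.
out_default = reference_attention(q, k, v, softmax_scale=None)
out_scaled = reference_attention(q, k, v, softmax_scale=1.0 / 128)
assert not torch.allclose(out_default, out_scaled)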