@@ -546,8 +546,6 @@ def get_benchmark(
546
546
providers_filter : Optional [list [str ]] = None ,
547
547
fa_kernel_mode = 'fwd' ,
548
548
attn_fwd = _attn_fwd_with_block_pointers ,
549
- xetla_assert_result = False ,
550
- xetla_warn_mismatch = False ,
551
549
):
552
550
"""
553
551
Returns a Mark object containing a Benchmark object constructed at runtime and parameterized by the provided option values.
@@ -649,30 +647,10 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
649
647
elif provider == 'xetla' :
650
648
xetla_fn = None
651
649
if MODE == 'fwd' :
652
- module_name = f'flash_attn_causal_{ CAUSAL } ' .lower ()
653
- func = getattr (xetla_kernel , module_name )
654
- out = torch .empty_like (q , device = 'xpu' , dtype = dtype )
655
- size_score = Z * H * N_CTX * N_CTX
656
- size_attn_mask = Z * N_CTX * N_CTX
657
- dropout_mask = torch .empty ((size_score , ), device = 'xpu' , dtype = torch .uint8 )
658
- bias = torch .empty ((size_attn_mask , ), device = 'xpu' , dtype = dtype )
659
- size_ml = Z * H * N_CTX
660
- m = torch .empty ((size_ml , ), device = 'xpu' , dtype = torch .float )
661
- l = torch .empty ((size_ml , ), device = 'xpu' , dtype = torch .float )
662
-
663
- def xetla_fwd_fn ():
664
- func (q , k , v , out , dropout_mask , bias , m , l , Z , H , D_HEAD , N_CTX , N_CTX , sm_scale )
665
- return out
666
-
667
- xetla_fn = xetla_fwd_fn
668
-
669
- def check_xetla_fwd_result ():
670
- if xetla_assert_result :
671
- benchmark_suite .assert_close (xetla_fn , torch_fn , atol = atol , rtol = 1e-3 , err_msg = 'xetla to torch' )
672
- elif xetla_warn_mismatch :
673
- check_close (xetla_fn , torch_fn , atol , 1e-3 )
674
-
675
- check_xetla_fwd_result ()
650
+ min_ms = float ('nan' )
651
+ max_ms = float ('nan' )
652
+ mean = float ('nan' )
653
+ cv = float ('nan' )
676
654
677
655
if MODE == 'bwd' :
678
656
module_name = f'flash_attn_bwd_causal_{ CAUSAL } ' .lower ()
@@ -711,8 +689,6 @@ def xetla_bwd_fn():
711
689
)
712
690
713
691
elif provider == 'cutlass' :
714
- cutlass_fn = None
715
-
716
692
if MODE == 'fwd' :
717
693
name = 'attention'
718
694
func = getattr (cutlass_kernel , name )
@@ -723,17 +699,15 @@ def cutlass_fwd_fn():
723
699
return out
724
700
725
701
benchmark_suite .assert_close (cutlass_fwd_fn , torch_fn , atol = atol , rtol = 1e-3 , err_msg = 'cutlass to torch' )
726
- cutlass_fn = cutlass_fwd_fn
727
702
728
703
_ , min_ms , max_ms , mean , cv = benchmark_suite .do_bench (
729
- cutlass_fn ,
704
+ cutlass_fwd_fn ,
730
705
n_warmup = 10 ,
731
706
n_repeat = 10 ,
732
707
quantiles = quantiles ,
733
708
)
734
709
735
710
else :
736
- cutlass_fn = None
737
711
min_ms = float ('nan' )
738
712
max_ms = float ('nan' )
739
713
mean = float ('nan' )
@@ -757,7 +731,5 @@ def cutlass_fwd_fn():
757
731
if __name__ == '__main__' :
758
732
_benchmark = get_benchmark (
759
733
fa_kernel_mode = os .getenv ('FA_KERNEL_MODE' , 'fwd' ),
760
- xetla_assert_result = (os .getenv ('XETLA_ASSERT_RESULT' , '0' ) == '1' ),
761
- xetla_warn_mismatch = (os .getenv ('XETLA_WARN_MISMATCH' , '0' ) == '1' ),
762
734
)
763
735
_benchmark .run (show_plots = False , print_data = True )
0 commit comments