Skip to content

Commit 5d18616

Browse files
[FlashAttention] Remove XeTLA for fwd mode
Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
1 parent 8b73e5a commit 5d18616

File tree

3 files changed

+5
-41
lines changed

3 files changed

+5
-41
lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -276,7 +276,6 @@ jobs:
276276
277277
source ../../scripts/capture-hw-details.sh
278278
python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark flash-attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
279-
python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark flash-attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
280279
python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-cutlass-report.csv --benchmark flash-attn --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG
281280
282281
- name: Run Triton FA bwd kernel benchmark
@@ -302,7 +301,6 @@ jobs:
302301
303302
source ../../scripts/capture-hw-details.sh
304303
python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-triton-report.csv --benchmark flash-attn-tensor-desc --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
305-
python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-xetla-report.csv --benchmark flash-attn-tensor-desc --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
306304
python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-cutlass-report.csv --benchmark flash-attn-tensor-desc --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG
307305
308306
- name: Run Triton FlexAttention Causal Mask fwd kernel benchmark

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 5 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -546,8 +546,6 @@ def get_benchmark(
546546
providers_filter: Optional[list[str]] = None,
547547
fa_kernel_mode='fwd',
548548
attn_fwd=_attn_fwd_with_block_pointers,
549-
xetla_assert_result=False,
550-
xetla_warn_mismatch=False,
551549
):
552550
"""
553551
Returns a Mark object containing a Benchmark object constructed at runtime and parameterized by the provided option values.
@@ -649,30 +647,10 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
649647
elif provider == 'xetla':
650648
xetla_fn = None
651649
if MODE == 'fwd':
652-
module_name = f'flash_attn_causal_{CAUSAL}'.lower()
653-
func = getattr(xetla_kernel, module_name)
654-
out = torch.empty_like(q, device='xpu', dtype=dtype)
655-
size_score = Z * H * N_CTX * N_CTX
656-
size_attn_mask = Z * N_CTX * N_CTX
657-
dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
658-
bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
659-
size_ml = Z * H * N_CTX
660-
m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
661-
l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
662-
663-
def xetla_fwd_fn():
664-
func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
665-
return out
666-
667-
xetla_fn = xetla_fwd_fn
668-
669-
def check_xetla_fwd_result():
670-
if xetla_assert_result:
671-
benchmark_suite.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch')
672-
elif xetla_warn_mismatch:
673-
check_close(xetla_fn, torch_fn, atol, 1e-3)
674-
675-
check_xetla_fwd_result()
650+
min_ms = float('nan')
651+
max_ms = float('nan')
652+
mean = float('nan')
653+
cv = float('nan')
676654

677655
if MODE == 'bwd':
678656
module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower()
@@ -711,8 +689,6 @@ def xetla_bwd_fn():
711689
)
712690

713691
elif provider == 'cutlass':
714-
cutlass_fn = None
715-
716692
if MODE == 'fwd':
717693
name = 'attention'
718694
func = getattr(cutlass_kernel, name)
@@ -723,17 +699,15 @@ def cutlass_fwd_fn():
723699
return out
724700

725701
benchmark_suite.assert_close(cutlass_fwd_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='cutlass to torch')
726-
cutlass_fn = cutlass_fwd_fn
727702

728703
_, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
729-
cutlass_fn,
704+
cutlass_fwd_fn,
730705
n_warmup=10,
731706
n_repeat=10,
732707
quantiles=quantiles,
733708
)
734709

735710
else:
736-
cutlass_fn = None
737711
min_ms = float('nan')
738712
max_ms = float('nan')
739713
mean = float('nan')
@@ -757,7 +731,5 @@ def cutlass_fwd_fn():
757731
if __name__ == '__main__':
758732
_benchmark = get_benchmark(
759733
fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
760-
xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
761-
xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '0') == '1'),
762734
)
763735
_benchmark.run(show_plots=False, print_data=True)

benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -141,22 +141,16 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
141141
def get_benchmark(
142142
providers_filter: Optional[list[str]] = None,
143143
fa_kernel_mode='fwd',
144-
xetla_assert_result=False,
145-
xetla_warn_mismatch=False,
146144
):
147145
return flash_attention_benchmark.get_benchmark(
148146
providers_filter=providers_filter,
149147
fa_kernel_mode=fa_kernel_mode,
150148
attn_fwd=_attn_fwd_with_tensor_desc,
151-
xetla_assert_result=xetla_assert_result,
152-
xetla_warn_mismatch=xetla_warn_mismatch,
153149
)
154150

155151

156152
if __name__ == '__main__':
157153
_benchmark = get_benchmark(
158154
fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
159-
xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
160-
xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '0') == '1'),
161155
)
162156
_benchmark.run(show_plots=False, print_data=True)

0 commit comments

Comments (0)