
Commit a801ce1

feat: add flash attn bench
1 parent 9d825eb commit a801ce1

File tree

flash_attn/benchmark.md

1 file changed: +371 -0 lines changed

---
title: "Flash Attention Benchmark"
author: "uvnote"
theme: "dark"
syntax_theme: "monokai"
show_line_numbers: true
collapse_code: false
custom_css: |
  #output-setup {
    overflow-x: auto;
  }
  .cell-output {
    overflow: scroll;
  }
  .cell-stdout {
    width: max-content;
    overflow: scroll;
  }
  .cell-stderr {
    width: max-content;
    overflow: scroll;
    max-height: 300px;
  }
---

```python id=benchmark
# /// script
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels",
#   "pandas",
#   "matplotlib"
# ]
# ///
# Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths

import functools
import os
import pathlib

import matplotlib.pyplot as plt
import torch
import torch._dynamo.config
import triton
import triton.language as tl

try:
    from flash_attn import flash_attn_func
except Exception:
    flash_attn_func = None
    print("Flash Attention 2 not found.")

try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except Exception:
    flash_attn_3_func = None
    print("Flash Attention 3 not found.")

try:
    from kernels import get_kernel
    hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
    hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3")
except Exception:
    hf_kernels_flash_attn = None
    hf_kernels_flash_attn_3 = None
    print("HF Kernels not found.")

try:
    from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_qk_int8_pv_fp8_cuda_sm90
except Exception:
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    print("SageAttention not found.")

try:
    from transformer_engine.pytorch.attention import DotProductAttention
except Exception:
    DotProductAttention = None
    print("Transformer Engine not found.")

try:
    import xformers.ops as xops
except Exception:
    xops = None
    print("xFormers not found.")


plt.rcParams.update({
    "figure.figsize": (12, 10),
    "figure.dpi": 120,
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 8,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
    "lines.linewidth": 2.0,
    "lines.markersize": 6,
    "legend.frameon": True,
    "legend.framealpha": 0.9,
    "legend.loc": "best",
    "axes.spines.top": False,
    "axes.spines.right": False,
})


# We want to compare the best compiled version for each specific shape (dynamic=False)
torch._dynamo.config.cache_size_limit = 10000

# We need suppress_errors for FA3 to work under torch.compile; it makes FA3 fall back to eager mode.
# I can't get it to work any other way, so suggestions are welcome!
torch._dynamo.config.suppress_errors = True

output_dir = pathlib.Path("dump_attention_benchmark")
output_dir.mkdir(parents=True, exist_ok=True)

batch_size = 1
num_attention_heads = 24
attention_head_dim = 128
image_sequence_length = 4096  # 1024x1024px
text_sequence_lengths = [128, 256, 320, 384, 448, 512]
sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths]


def _attention_torch(query, key, value, *, backend):
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(backend):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).contiguous()
    return out
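# Note: the benchmark tensors use a (batch, seq_len, heads, head_dim) layout, while SDPA
# expects (batch, heads, seq_len, head_dim); hence the transposes on the way in and out.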


_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False)
def _attention_torch_compile_default(query, key, value, *, backend):
    return _compiled_attention_torch_default(query, key, value, backend=backend)


_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_torch_compile_max_autotune(query, key, value, *, backend):
    return _compiled_attention_torch_max_autotune(query, key, value, backend=backend)


def _attention_flash_attn_2(query, key, value):
    return flash_attn_func(query, key, value)


_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False)
def _attention_flash_attn_2_compile_default(query, key, value):
    return _compiled_flash_attn_2_default(query, key, value)


_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_flash_attn_2_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_2_max_autotune(query, key, value)


# For fullgraph=True tracing to be compatible
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    out, lse = flash_attn_3_func(query, key, value)
    return out


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)
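# The fake (meta) kernel above only reports output metadata: FA3's output has the
# same shape and dtype as `query`, so `torch.empty_like(query)` is enough for
# torch.compile to trace the custom op with fullgraph=True without launching CUDA work.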


def _attention_flash_attn_3(query, key, value):
    out = _wrapped_flash_attn_3(query, key, value)
    return out


_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False)
def _attention_flash_attn_3_compile_default(query, key, value):
    return _compiled_flash_attn_3_default(query, key, value)


_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_flash_attn_3_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_3_max_autotune(query, key, value)


def _attention_hf_kernels_flash_attn(query, key, value):
    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]


def _attention_hf_kernels_flash_attn3(query, key, value):
    return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0]


def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
    return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
    return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
    return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")


if DotProductAttention is not None:
    def set_te_backend(backend):
        # must be applied before first use of
        # transformer_engine.pytorch.attention
        os.environ["NVTE_FLASH_ATTN"] = '0'
        os.environ["NVTE_FUSED_ATTN"] = '0'
        os.environ["NVTE_UNFUSED_ATTN"] = '0'
        if backend == 'flash':
            os.environ["NVTE_FLASH_ATTN"] = '1'
        if backend == 'fused':
            os.environ["NVTE_FUSED_ATTN"] = '1'
        if backend == 'unfused':
            os.environ["NVTE_UNFUSED_ATTN"] = '1'

    set_te_backend("fused")
    te_attn_fn = DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=attention_head_dim,
        qkv_format="bshd",
        attn_mask_type="no_mask",
    )
else:
    def te_attn_fn(query, key, value):
        raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")

def _attention_te(query, key, value):
    out = te_attn_fn(query, key, value)
    out = out.unflatten(2, (num_attention_heads, attention_head_dim))
    return out
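# Note: with qkv_format="bshd", TE's DotProductAttention returns the heads flattened
# into the last dimension (batch, seq, heads * head_dim), so the output is unflattened
# back to (batch, seq, heads, head_dim) to match the other backends.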


# Cannot fullgraph compile TE
_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False)
def _attention_te_compile_default(query, key, value):
    return _compiled_te_attn_fn_default(query, key, value)


# Cannot fullgraph compile TE
_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False)
def _attention_te_compile_max_autotune(query, key, value):
    return _compiled_te_attn_fn_max_autotune(query, key, value)


def _attention_xformers(query, key, value):
    return xops.memory_efficient_attention(query, key, value)


_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False)
def _attention_xformers_compile_default(query, key, value):
    return _compiled_xformers_default(query, key, value)


_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_xformers_compile_max_autotune(query, key, value):
    return _compiled_xformers_max_autotune(query, key, value)


attention_ops = {}
attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
if hf_kernels_flash_attn is not None:
    attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn
    attention_ops["hf_flash_attn3"] = _attention_hf_kernels_flash_attn3
if flash_attn_func is not None:
    attention_ops["flash_attn_2"] = _attention_flash_attn_2
    attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default
    attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune
if flash_attn_3_func is not None:
    attention_ops["flash_attn_3"] = _attention_flash_attn_3
    attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default
    attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune
if sageattn_qk_int8_pv_fp16_cuda is not None:
    attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda
    attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton
    if torch.cuda.get_device_capability()[0] >= 9:
        attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90
if DotProductAttention is not None:
    attention_ops["te_fused"] = _attention_te
    attention_ops["te_fused_compile_d"] = _attention_te_compile_default
    attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune
if xops is not None:
    attention_ops["xformers"] = _attention_xformers
    attention_ops["xformers_compile_d"] = _attention_xformers_compile_default
    attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune


def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
    line_styles = ["-", ":", "-.", "--"]
    if n > len(colors) * len(line_styles):
        raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}")
    styles = []
    for i in range(n):
        color = colors[i % len(colors)]
        linestyle = line_styles[i // len(colors)]
        styles.append((color, linestyle))
    return styles


def correctness():
    for seq_len in sequence_lengths:
        shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
        print(f"\n\n===== Testing shape: {shape} =====")

        query = torch.randn(shape, device="cuda", dtype=torch.float32)
        key = torch.randn(shape, device="cuda", dtype=torch.float32)
        value = torch.randn(shape, device="cuda", dtype=torch.float32)

        # The float32 SDPA math backend serves as the reference; all other backends run in bfloat16.
        golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH)
        query, key, value = (x.bfloat16() for x in (query, key, value))

        for name, fn in attention_ops.items():
            out = fn(query, key, value)
            absdiff = (out - golden_truth).abs()
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean((golden_truth - out) ** 2)
            print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}")


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=sequence_lengths,
        x_log=False,
        line_arg="provider",
        line_vals=list(attention_ops.keys()),
        line_names=list(attention_ops.keys()),
        ylabel="Time (ms)",
        styles=get_color_and_linestyle(len(attention_ops)),
        plot_name="Attention Benchmark",
        args={},
    )
)
def benchmark_fn(seq_len: int, provider: str):
    torch.manual_seed(0)

    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
    query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)

    fn = attention_ops[provider]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(query, key, value),
        warmup=3,
        rep=10,
        quantiles=[0.5, 0.2, 0.8],
    )
    return ms, max_ms, min_ms


with torch.inference_mode():
    correctness()
    benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())

```
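
For quick spot checks outside the full sweep, the timing harness can also be driven directly. The sketch below is an illustrative addition, not part of the committed script: it times a single SDPA flash-attention call with the same `triton.testing.do_bench` settings and (batch, seq_len, heads, head_dim) layout used above, assuming a CUDA device with bfloat16 support.

```python
# Minimal sketch: time one backend in isolation.
import torch
import triton.testing

shape = (1, 4096 + 512, 24, 128)  # batch, image + text seq_len, heads, head_dim
q, k, v = (torch.randn(shape, device="cuda", dtype=torch.bfloat16) for _ in range(3))

def sdpa_flash(q, k, v):
    # SDPA expects (batch, heads, seq, head_dim), so transpose in and out.
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2)

# quantiles=[0.5, 0.2, 0.8] returns the median, 20th-, and 80th-percentile times in ms.
median_ms, p20_ms, p80_ms = triton.testing.do_bench(
    lambda: sdpa_flash(q, k, v), warmup=3, rep=10, quantiles=[0.5, 0.2, 0.8]
)
print(f"sdpa_flash: median={median_ms:.3f} ms (p20={p20_ms:.3f}, p80={p80_ms:.3f})")
```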
