---
title: "Flash Attention Benchmark"
author: "uvnote"
theme: "dark"
syntax_theme: "monokai"
show_line_numbers: true
collapse_code: false
custom_css: |
  #output-setup {
    overflow-x: auto;
  }
  .cell-output {
    overflow: scroll;
  }
  .cell-stdout {
    width: max-content;
    overflow: scroll;
  }
  .cell-stderr {
    width: max-content;
    overflow: scroll;
    max-height: 300px;
  }
---

```python id=benchmark
# /// script
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels",
#     "pandas",
#     "matplotlib"
# ]
# ///
# Benchmarking common shapes for a Flux 1024x1024px image + varying text sequence lengths.

import functools
import os
import pathlib

import matplotlib.pyplot as plt
import torch
import torch._dynamo.config
import triton
import triton.language as tl

try:
    from flash_attn import flash_attn_func
except Exception:
    flash_attn_func = None
    print("Flash Attention 2 not found.")

try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except Exception:
    flash_attn_3_func = None
    print("Flash Attention 3 not found.")

try:
    from kernels import get_kernel
    hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
    hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3")
except Exception:
    hf_kernels_flash_attn = None
    hf_kernels_flash_attn_3 = None
    print("HF Kernels not found.")

try:
    from sageattention import (
        sageattn_qk_int8_pv_fp16_cuda,
        sageattn_qk_int8_pv_fp16_triton,
        sageattn_qk_int8_pv_fp8_cuda_sm90,
    )
except Exception:
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    print("SageAttention not found.")

try:
    from transformer_engine.pytorch.attention import DotProductAttention
except Exception:
    DotProductAttention = None
    print("Transformer Engine not found.")

try:
    import xformers.ops as xops
except Exception:
    xops = None
    print("xFormers not found.")


plt.rcParams.update({
    "figure.figsize": (12, 10),
    "figure.dpi": 120,
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 8,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
    "lines.linewidth": 2.0,
    "lines.markersize": 6,
    "legend.frameon": True,
    "legend.framealpha": 0.9,
    "legend.loc": "best",
    "axes.spines.top": False,
    "axes.spines.right": False,
})


# We want to compare the best compiled version for each specific shape (dynamic=False),
# so allow a large number of compiled graphs to be cached.
torch._dynamo.config.cache_size_limit = 10000

# We need suppress_errors for FA3 to work under torch.compile; it falls back to eager mode.
# I can't seem to get it to work any other way, so any suggestions are welcome!
torch._dynamo.config.suppress_errors = True

output_dir = pathlib.Path("dump_attention_benchmark")
output_dir.mkdir(parents=True, exist_ok=True)

batch_size = 1
num_attention_heads = 24
attention_head_dim = 128
image_sequence_length = 4096  # 1024x1024px
text_sequence_lengths = [128, 256, 320, 384, 448, 512]
sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths]
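# Total benchmarked sequence lengths (image tokens + text tokens): 4224, 4352, 4416, 4480, 4544, 4608.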


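# All attention wrappers below take inputs shaped (batch, seq_len, num_heads, head_dim);
# PyTorch SDPA expects (batch, num_heads, seq_len, head_dim), hence the transposes around the call.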
def _attention_torch(query, key, value, *, backend):
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(backend):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).contiguous()
    return out


_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False)
def _attention_torch_compile_default(query, key, value, *, backend):
    return _compiled_attention_torch_default(query, key, value, backend=backend)


_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_torch_compile_max_autotune(query, key, value, *, backend):
    return _compiled_attention_torch_max_autotune(query, key, value, backend=backend)


def _attention_flash_attn_2(query, key, value):
    return flash_attn_func(query, key, value)


_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False)
def _attention_flash_attn_2_compile_default(query, key, value):
    return _compiled_flash_attn_2_default(query, key, value)


_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_flash_attn_2_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_2_max_autotune(query, key, value)


# For fullgraph=True tracing to be compatible
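# FA3 is not traceable directly, so it is wrapped in a torch.library custom op; the fake (meta)
# registration below only describes the output's shape/dtype so dynamo can trace the graph
# without actually running the kernel.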
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    out, lse = flash_attn_3_func(query, key, value)
    return out


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)


def _attention_flash_attn_3(query, key, value):
    out = _wrapped_flash_attn_3(query, key, value)
    return out


_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False)
def _attention_flash_attn_3_compile_default(query, key, value):
    return _compiled_flash_attn_3_default(query, key, value)


_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_flash_attn_3_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_3_max_autotune(query, key, value)


def _attention_hf_kernels_flash_attn(query, key, value):
    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]


def _attention_hf_kernels_flash_attn3(query, key, value):
    return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0]


def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
    return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
    return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
    return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")


if DotProductAttention is not None:
    def set_te_backend(backend):
        # Must be applied before the first use of transformer_engine.pytorch.attention.
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        if backend == "flash":
            os.environ["NVTE_FLASH_ATTN"] = "1"
        if backend == "fused":
            os.environ["NVTE_FUSED_ATTN"] = "1"
        if backend == "unfused":
            os.environ["NVTE_UNFUSED_ATTN"] = "1"

    set_te_backend("fused")
    te_attn_fn = DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=attention_head_dim,
        qkv_format="bshd",
        attn_mask_type="no_mask",
    )
else:
    def te_attn_fn(query, key, value):
        raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")

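# DotProductAttention with qkv_format="bshd" returns a flattened (batch, seq, heads * head_dim)
# tensor, so the output is unflattened back to (batch, seq, heads, head_dim) to match the other ops.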
def _attention_te(query, key, value):
    out = te_attn_fn(query, key, value)
    out = out.unflatten(2, (num_attention_heads, attention_head_dim))
    return out


# Cannot fullgraph compile TE
_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False)
def _attention_te_compile_default(query, key, value):
    return _compiled_te_attn_fn_default(query, key, value)


# Cannot fullgraph compile TE
_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False)
def _attention_te_compile_max_autotune(query, key, value):
    return _compiled_te_attn_fn_max_autotune(query, key, value)


def _attention_xformers(query, key, value):
    return xops.memory_efficient_attention(query, key, value)


_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False)
def _attention_xformers_compile_default(query, key, value):
    return _compiled_xformers_default(query, key, value)


_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_xformers_compile_max_autotune(query, key, value):
    return _compiled_xformers_max_autotune(query, key, value)


attention_ops = {}
attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
if hf_kernels_flash_attn is not None:
    attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn
    attention_ops["hf_flash_attn3"] = _attention_hf_kernels_flash_attn3
if flash_attn_func is not None:
    attention_ops["flash_attn_2"] = _attention_flash_attn_2
    attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default
    attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune
if flash_attn_3_func is not None:
    attention_ops["flash_attn_3"] = _attention_flash_attn_3
    attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default
    attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune
if sageattn_qk_int8_pv_fp16_cuda is not None:
    attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda
    attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton
    if torch.cuda.get_device_capability()[0] >= 9:
        attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90
if DotProductAttention is not None:
    attention_ops["te_fused"] = _attention_te
    attention_ops["te_fused_compile_d"] = _attention_te_compile_default
    attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune
if xops is not None:
    attention_ops["xformers"] = _attention_xformers
    attention_ops["xformers_compile_d"] = _attention_xformers_compile_default
    attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune
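# Naming convention: "*_compile_d" uses torch.compile(mode="default") and "*_compile_ma" uses
# torch.compile(mode="max-autotune").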


def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
    line_styles = ["-", ":", "-.", "--"]
    if n > len(colors) * len(line_styles):
        raise ValueError(f"Requested {n=} styles but the maximum is {len(colors) * len(line_styles)}")
    styles = []
    for i in range(n):
        color = colors[i % len(colors)]
        linestyle = line_styles[i // len(colors)]
        styles.append((color, linestyle))
    return styles


def correctness():
    for seq_len in sequence_lengths:
        shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
        print(f"\n\n===== Testing shape: {shape} =====")

        query = torch.randn(shape, device="cuda", dtype=torch.float32)
        key = torch.randn(shape, device="cuda", dtype=torch.float32)
        value = torch.randn(shape, device="cuda", dtype=torch.float32)

        golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH)
        query, key, value = (x.bfloat16() for x in (query, key, value))
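        # The reference is computed in float32 with the math backend; every candidate op below runs
        # in bfloat16, so the reported errors include the downcast as well as kernel differences.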

        for name, fn in attention_ops.items():
            out = fn(query, key, value)
            absdiff = (out - golden_truth).abs()
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean((golden_truth - out) ** 2)
            print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}")


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=sequence_lengths,
        x_log=False,
        line_arg="provider",
        line_vals=list(attention_ops.keys()),
        line_names=list(attention_ops.keys()),
        ylabel="Time (ms)",
        styles=get_color_and_linestyle(len(attention_ops)),
        plot_name="Attention Benchmark",
        args={},
    )
)
def benchmark_fn(seq_len: int, provider: str):
    torch.manual_seed(0)

    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
    query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)

    fn = attention_ops[provider]
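    # do_bench reports the 0.5/0.2/0.8 quantiles of the timing distribution in ms; the 0.2 quantile
    # is the fast end and the 0.8 quantile the slow end, returned below in the (median, max, min)
    # order that perf_report expects.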
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(query, key, value),
        warmup=3,
        rep=10,
        quantiles=[0.5, 0.2, 0.8],
    )
    return ms, max_ms, min_ms


with torch.inference_mode():
    correctness()
    benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())

```
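
The run prints the per-provider timings (`print_data=True`) and saves Triton's plot and CSV into `dump_attention_benchmark/`. As a rough follow-up sketch (the cell id is made up, and the glob over `*.csv` avoids assuming the exact filename Triton derives from `plot_name`), the saved results can be reloaded with pandas for further analysis:

```python id=inspect_results
# Hypothetical follow-up cell: reload whatever CSV triton.testing saved under save_path.
import glob

import pandas as pd

for path in sorted(glob.glob("dump_attention_benchmark/*.csv")):
    df = pd.read_csv(path)  # one row per seq_len, one column per provider (times in ms)
    print(path)
    print(df.to_string(index=False))
```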