
Commit 50a639e

jduprat authored and facebook-github-bot committed

WIP: TritonBench for FA4 (#296)
Summary: Flash Attention 4 work in progress. Allows running both Tri Dao's benchmarking script and TritonBench. Developed as a collaboration between JLD and Sarunya.

Reviewed By: devashishshankar

Differential Revision: D78096306
1 parent 608f961 commit 50a639e

File tree: 1 file changed, +20 -0 lines

tritonbench/operators/blackwell_attentions/operator.py

Lines changed: 20 additions & 0 deletions
@@ -44,6 +44,13 @@
 except (ImportError, IOError, AttributeError):
     HAS_FLASH_V2 = False
 
+# [Optional] CuTe
+try:
+    import flash_attn.cute.interface as facute
+    HAS_FLASH_CUTE = True
+except (ImportError, IOError, AttributeError):
+    HAS_FLASH_CUTE = False
+
 # [Optional] xformers backend
 try:
     import xformers  # @manual=//fair/xformers:xformers
@@ -266,6 +273,19 @@ def sdpa_flash_attention(q, k, v):
             v,
         )
 
+    @register_benchmark(enabled=(IS_B200 and HAS_FLASH_CUTE), label=f"cutedsl-blackwell", fwd_only=True)
+    def cutedsl_blackwell(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor) -> Callable:
+
+        # [B, H, S, D] -> [B, S, H, D]
+        q = q.transpose(1, 2).contiguous()
+        k = k.transpose(1, 2).contiguous()
+        v = v.transpose(1, 2).contiguous()
+        return lambda: facute.flash_attn_func(q, k, v, self.sm_scale, self.causal)
+
     @register_benchmark()
     def flex_attention(self, q, k, v):
         from torch.nn.attention.flex_attention import create_block_mask, flex_attention
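
Note: below is a minimal standalone sketch of the call pattern the new cutedsl_blackwell benchmark wraps. Only the [B, H, S, D] -> [B, S, H, D] transpose and the facute.flash_attn_func(q, k, v, sm_scale, causal) call come from the diff above; the tensor shapes, dtype, tuple-return handling, and the SDPA cross-check are illustrative assumptions, not part of the commit.

import torch

# Mirror the optional-import guard added in this commit.
try:
    import flash_attn.cute.interface as facute
    HAS_FLASH_CUTE = True
except (ImportError, IOError, AttributeError):
    HAS_FLASH_CUTE = False


def cutedsl_reference_check(B=2, H=8, S=1024, D=128, causal=False):
    # Hypothetical shapes; TritonBench hands the operator [B, H, S, D] tensors.
    q, k, v = (
        torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3)
    )
    sm_scale = D**-0.5

    # The CuTe interface expects [B, S, H, D], hence the transposes in the benchmark body.
    q_t, k_t, v_t = (t.transpose(1, 2).contiguous() for t in (q, k, v))
    out = facute.flash_attn_func(q_t, k_t, v_t, sm_scale, causal)
    if isinstance(out, tuple):  # assumption: some builds also return the logsumexp
        out = out[0]

    # Cross-check against PyTorch SDPA in the original [B, H, S, D] layout.
    ref = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=causal, scale=sm_scale
    )
    torch.testing.assert_close(out.transpose(1, 2), ref, rtol=2e-2, atol=2e-2)


if __name__ == "__main__" and HAS_FLASH_CUTE and torch.cuda.is_available():
    cutedsl_reference_check()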
