diff --git a/tritonbench/operators/blackwell_attentions/operator.py b/tritonbench/operators/blackwell_attentions/operator.py
index b3c43cee..2db990bf 100644
--- a/tritonbench/operators/blackwell_attentions/operator.py
+++ b/tritonbench/operators/blackwell_attentions/operator.py
@@ -44,6 +44,14 @@
 except (ImportError, IOError, AttributeError):
     HAS_FLASH_V2 = False
 
+# [Optional] CuTe
+try:
+    import flash_attn.cute.interface as facute
+
+    HAS_FLASH_CUTE = True
+except (ImportError, IOError, AttributeError):
+    HAS_FLASH_CUTE = False
+
 # [Optional] xformers backend
 try:
     import xformers  # @manual=//fair/xformers:xformers
@@ -266,6 +274,18 @@ def sdpa_flash_attention(q, k, v):
             v,
         )
 
+    @register_benchmark(
+        enabled=(IS_B200 and HAS_FLASH_CUTE), label=f"cutedsl-blackwell", fwd_only=True
+    )
+    def cutedsl_blackwell(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> Callable:
+        # [B, H, S, D] -> [B, S, H, D]
+        q = q.transpose(1, 2).contiguous()
+        k = k.transpose(1, 2).contiguous()
+        v = v.transpose(1, 2).contiguous()
+        return lambda: facute.flash_attn_func(q, k, v, self.sm_scale, self.causal)
+
     @register_benchmark()
     def flex_attention(self, q, k, v):
         from torch.nn.attention.flex_attention import create_block_mask, flex_attention
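
For context on what the new benchmark exercises, here is a minimal standalone sketch, assuming the flash-attn CuTe DSL interface (`flash_attn.cute.interface`) is installed and a Blackwell (B200) GPU is available. The shapes, dtype, and scale below are illustrative assumptions, not taken from the PR; only the layout transpose and the positional `flash_attn_func(q, k, v, sm_scale, causal)` call mirror the diff.

```python
# Sketch of the cutedsl_blackwell path outside the TritonBench harness.
# Assumptions: flash-attn's CuTe DSL interface is importable and a B200 GPU
# is present; B/H/S/D, dtype, and sm_scale are illustrative values.
import math

import torch
import flash_attn.cute.interface as facute

B, H, S, D = 4, 16, 2048, 128
dtype = torch.bfloat16

# The operator hands tensors to each backend in [B, H, S, D] layout.
q = torch.randn(B, H, S, D, device="cuda", dtype=dtype)
k = torch.randn(B, H, S, D, device="cuda", dtype=dtype)
v = torch.randn(B, H, S, D, device="cuda", dtype=dtype)

sm_scale = 1.0 / math.sqrt(D)
causal = True

# The CuTe DSL kernel expects [B, S, H, D], hence the transpose + contiguous
# in the benchmark body before the timed call.
q_bshd = q.transpose(1, 2).contiguous()
k_bshd = k.transpose(1, 2).contiguous()
v_bshd = v.transpose(1, 2).contiguous()

# Arguments are passed positionally, matching the diff; the exact return
# convention (output only vs. output plus LSE) may vary by flash-attn version.
out = facute.flash_attn_func(q_bshd, k_bshd, v_bshd, sm_scale, causal)
```

Note that the layout conversion is done eagerly in `cutedsl_blackwell` and only the kernel call is captured in the returned lambda, so the transposes are not included in the measured time. Assuming the usual TritonBench entry point, something like `python run.py --op blackwell_attentions --only cutedsl_blackwell` would exercise the new path (flag names assumed; check the harness docs).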