
Commit e6251f1

bertmaher authored and facebook-github-bot committed
Enable flash_v3 backward (#2445)
Summary: Pull Request resolved: #2445
Reviewed By: xuzhao9
Differential Revision: D61924864
Pulled By: bertmaher
fbshipit-source-id: 760036820c1196a921eaff4d99bf8647e25264ee
1 parent c0409aa commit e6251f1

1 file changed: 2 additions, 4 deletions

torchbenchmark/operators/flash_attention/operator.py

Lines changed: 2 additions & 4 deletions
@@ -73,7 +73,7 @@
 try:
     torch_lib_path = os.path.join(os.path.dirname(__file__), "lib")
     with add_ld_library_path(torch_lib_path):
-        import flashattn_hopper_cuda
+        from flash_attn_interface import flash_attn_func as flash_attn_v3
 except (ImportError, IOError, AttributeError):
     HAS_FLASH_V3 = False
     pass
@@ -223,9 +223,7 @@ def flash_v3(
         q = q.transpose(1, 2).contiguous()
         k = k.transpose(1, 2).contiguous()
         v = v.transpose(1, 2).contiguous()
-        fn = lambda: flashattn_hopper_cuda.fwd(
-            q, k, v, None, self.sm_scale, self.causal
-        )
+        fn = lambda: flash_attn_v3(q, k, v, self.sm_scale, self.causal)
         return fn

     @register_benchmark()
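
The substantive change is swapping the raw flashattn_hopper_cuda.fwd binding for the flash_attn_func wrapper exported by flash_attn_interface (imported as flash_attn_v3). The Python wrapper goes through PyTorch autograd, which is what allows the backward pass the commit title refers to. Below is a minimal sketch of exercising that call path; it is not code from the commit. It assumes a flash-attn v3 (Hopper) build that exposes flash_attn_interface.flash_attn_func and a GPU that supports it; the shapes, dtypes, and tuple-return handling are illustrative assumptions.

import torch
from flash_attn_interface import flash_attn_func as flash_attn_v3

# Illustrative shapes: batch, sequence length, heads, head dimension.
B, S, H, D = 4, 1024, 16, 128
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Same positional call shape as the benchmark: (q, k, v, softmax_scale, causal).
out = flash_attn_v3(q, k, v, D ** -0.5, True)
if isinstance(out, tuple):
    # Assumption: some flash-attn v3 builds also return softmax_lse; keep only the output.
    out = out[0]

# Backward works because the call goes through the autograd-aware Python wrapper
# rather than the raw flashattn_hopper_cuda.fwd kernel binding.
out.sum().backward()
print(q.grad.shape)  # torch.Size([4, 1024, 16, 128])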

Comments (0)