Commit a4d90ab

support new fa3 api (#10661)
1 parent f705b6b commit a4d90ab

File tree

1 file changed (+19, -8 lines)

1 file changed

+19
-8
lines changed

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 19 additions & 8 deletions
@@ -53,11 +53,7 @@
 except:
     pass
 
-try:
-    from paddle.nn.functional import flash_attn_v3
-    from paddle.nn.functional.flash_attention import flash_attention
-except:
-    flash_attention = None
+from paddle.nn.functional.flash_attention import flash_attention
 
 
 from paddle import _C_ops
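
The removed block used to guard the FlashAttention v3 import and fall back to flash_attention = None when it failed; after this change the import is unconditional and the v3 path is reached through _C_ops at the call site below. For reference, a minimal sketch of what the dropped guard did (assumption: callers elsewhere checked for None before using it; this is not part of the commit):

    try:
        # Old v3 entry point plus the v2 helper, both treated as optional.
        from paddle.nn.functional import flash_attn_v3
        from paddle.nn.functional.flash_attention import flash_attention
    except ImportError:
        flash_attention = None  # older Paddle builds without flash attention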
@@ -1063,9 +1059,24 @@ def forward(
             )
 
         elif FA_VERSION == 3:
-            attn_out, softmax_lse = flash_attn_v3(
-                query_states, key_states, value_states, softmax_scale=softmax_scale, causal=True
-            )
+            attn_out, softmax_lse = _C_ops.flash_attn_v3(
+                query_states,
+                key_states,
+                value_states,
+                None,  # q_v_
+                None,  # q_descale_
+                None,  # k_descale_
+                None,  # v_descale_
+                softmax_scale,
+                True,
+                -1,  # window_size_left
+                -1,  # window_size_right
+                0.0,  # softcap
+                1,  # num_splits
+                False,  # manual_set_pack_gqa
+                False,  # pack_gqa_
+                0,  # sm_margin
+            )
         else:
             assert False, f"invalid {FA_VERSION=}"
 
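For context, the new FA3 path calls the internal op with positional arguments; the order and the defaults are exactly those shown in the hunk above. Below is a minimal sketch (not part of this commit; the wrapper name and keyword-style interface are illustrative assumptions) of how the old keyword call maps onto the new positional signature:

    from paddle import _C_ops


    def fa3_attention(query_states, key_states, value_states, softmax_scale, causal=True):
        # Positional order taken from this commit: q, k, v, q_v_, q_descale_,
        # k_descale_, v_descale_, softmax_scale, causal, window_size_left,
        # window_size_right, softcap, num_splits, manual_set_pack_gqa,
        # pack_gqa_, sm_margin.
        attn_out, softmax_lse = _C_ops.flash_attn_v3(
            query_states,
            key_states,
            value_states,
            None,   # q_v_
            None,   # q_descale_
            None,   # k_descale_
            None,   # v_descale_
            softmax_scale,
            causal,
            -1,     # window_size_left: no sliding window
            -1,     # window_size_right: no sliding window
            0.0,    # softcap disabled
            1,      # num_splits
            False,  # manual_set_pack_gqa
            False,  # pack_gqa_
            0,      # sm_margin
        )
        return attn_out, softmax_lse

Under this sketch, the old call flash_attn_v3(query_states, key_states, value_states, softmax_scale=softmax_scale, causal=True) becomes fa3_attention(query_states, key_states, value_states, softmax_scale).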