Support flash-attn on the torch backend #2257
base: master
Changes from 11 commits
File: keras_hub/src/models/mixtral/mixtral_attention.py (linked in the review thread at the end)
@@ -27,19 +27,19 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self._num_query_heads = num_query_heads
-        self._num_key_value_heads = num_key_value_heads
-        self._sliding_window = sliding_window
-        self._dropout = dropout
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.sliding_window = sliding_window
+        self.dropout = dropout

-        self._num_key_value_groups = num_query_heads // num_key_value_heads
-        self._rope_max_wavelength = rope_max_wavelength
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength

         self._kernel_initializer = keras.initializers.get(
             clone_initializer(kernel_initializer)
         )

-        self._rope_scaling_factor = rope_scaling_factor
+        self.rope_scaling_factor = rope_scaling_factor

     def build(self, inputs_shape):
         # Einsum variables:
@@ -51,12 +51,12 @@ def build(self, inputs_shape):
         # v = num key/value heads
         # h = head dim
         self._hidden_dim = inputs_shape[-1]
-        self._head_dim = self._hidden_dim // self._num_query_heads
+        self._head_dim = self._hidden_dim // self.num_query_heads
         self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)

         self.query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
-            output_shape=(None, self._num_query_heads, self._head_dim),
+            output_shape=(None, self.num_query_heads, self._head_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.dtype_policy,
             name="query",
@@ -67,7 +67,7 @@ def build(self, inputs_shape):
             equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
-                self._num_key_value_heads,
+                self.num_key_value_heads,
                 self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
@@ -80,7 +80,7 @@ def build(self, inputs_shape):
             equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
-                self._num_key_value_heads,
+                self.num_key_value_heads,
                 self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
@@ -89,31 +89,31 @@ def build(self, inputs_shape):
         )
         self.value_dense.build(inputs_shape)

-        self._softmax = keras.layers.Softmax(
+        self.softmax = keras.layers.Softmax(
             axis=-1,
             dtype="float32",
             name="attention_softmax",
         )

-        self._dropout_layer = keras.layers.Dropout(
-            rate=self._dropout,
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
             dtype=self.dtype_policy,
         )

-        self._output_dense = keras.layers.EinsumDense(
+        self.output_dense = keras.layers.EinsumDense(
             equation="bquh,uhm->bqm",
             output_shape=(None, self._hidden_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.dtype_policy,
             name="attention_output",
         )
-        self._output_dense.build(
-            (None, None, self._num_query_heads, self._head_dim)
+        self.output_dense.build(
+            (None, None, self.num_query_heads, self._head_dim)
         )

         self.rotary_embedding_layer = RotaryEmbedding(
-            max_wavelength=self._rope_max_wavelength,
-            scaling_factor=self._rope_scaling_factor,
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,
             dtype=self.dtype_policy,
         )

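For context on the projections touched above, the sketch below (not part of this PR; sizes are illustrative assumptions) shows how an EinsumDense with the same "bqm,muh->bquh" equation maps hidden states of shape (batch, seq_len, hidden_dim) to per-head queries of shape (batch, seq_len, num_query_heads, head_dim).

import keras
import numpy as np

# Illustrative sizes only (assumptions, not values from this PR).
batch, seq_len, hidden_dim = 2, 16, 256
num_query_heads = 8
head_dim = hidden_dim // num_query_heads  # 32

# Same equation as the query projection above:
# b = batch, q = query length, m = model dim, u = num query heads, h = head dim.
query_dense = keras.layers.EinsumDense(
    equation="bqm,muh->bquh",
    output_shape=(None, num_query_heads, head_dim),
)
x = np.random.rand(batch, seq_len, hidden_dim).astype("float32")
print(query_dense(x).shape)  # (2, 16, 8, 32)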
@@ -168,39 +168,34 @@ def _compute_key_value(x):

         # [batch_shape, seq_len, num_key_value_heads, head_dim]
         # -> [batch_shape, seq_len, num_heads, head_dim]
-        key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
-        value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2)
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

         attention_output = self._compute_attention(
             query, key, value, attention_mask
         )

-        attention_output = self._dropout_layer(
+        attention_output = self.dropout_layer(
             attention_output, training=training
         )

-        attention_output = self._output_dense(attention_output)
+        attention_output = self.output_dense(attention_output)

         if cache is not None:
             return attention_output, cache
         return attention_output

     def _masked_softmax(self, attention_scores, attention_mask=None):
         if attention_mask is not None:
-            return self._softmax(
-                attention_scores, attention_mask[:, None, :, :]
-            )
-        return self._softmax(attention_scores)
+            return self.softmax(attention_scores, attention_mask[:, None, :, :])
+        return self.softmax(attention_scores)

     def _use_fused_attention_op(self):
         if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False
         if running_on_gpu():
             # GPU never supports softcap in the fused op.
             if self.logit_soft_cap is not None:
                 return False
             return gpu_supports_fused_attention_op()
         elif running_on_tpu():
             # TPU supports softcap with on keras >= 3.10.

Review thread on _use_fused_attention_op:
- "This needs to return False on the JAX backend."
- "Mixtral never uses self.logit_soft_cap, so I don't understand what you mean."
- "I see! Okay."
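As background for the ops.repeat calls shown above, here is a minimal sketch (shapes are assumed for illustration) of how grouped-query attention tiles the key/value heads so each query head has a matching key/value head before attention is computed.

import numpy as np
from keras import ops

# Assumed sizes for illustration only.
batch, seq_len, head_dim = 2, 16, 32
num_query_heads, num_key_value_heads = 8, 2
num_key_value_groups = num_query_heads // num_key_value_heads  # 4

key = np.random.rand(batch, seq_len, num_key_value_heads, head_dim)
# Each key/value head is repeated so the head axis matches num_query_heads.
key = ops.repeat(key, repeats=num_key_value_groups, axis=2)
print(ops.shape(key))  # (2, 16, 8, 32)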
@@ -215,18 +210,12 @@ def _compute_attention(self, query, key, value, attention_mask=None):
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")

-            if self.logit_soft_cap:
-                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
-            else:
-                kwargs = {}
-
             attention_output = ops.dot_product_attention(
                 query,
                 key,
                 value,
                 mask=attention_mask,
                 scale=self._inv_norm_factor,
-                **kwargs,
             )
             return attention_output

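For readers unfamiliar with the fused path this PR enables on the torch backend, below is a minimal usage sketch of keras.ops.dot_product_attention with a boolean mask and an explicit scale, mirroring the call in the hunk above; the tensor shapes and sizes are assumptions for illustration.

import math
import numpy as np
from keras import ops

# Assumed layout: (batch, seq_len, num_heads, head_dim).
batch, seq_len, num_heads, head_dim = 2, 16, 8, 32
query = np.random.rand(batch, seq_len, num_heads, head_dim).astype("float32")
key = np.random.rand(batch, seq_len, num_heads, head_dim).astype("float32")
value = np.random.rand(batch, seq_len, num_heads, head_dim).astype("float32")

# Boolean mask broadcast over heads: (batch, 1, query_len, key_len).
mask = np.ones((batch, 1, seq_len, seq_len), dtype=bool)

attention_output = ops.dot_product_attention(
    query,
    key,
    value,
    mask=mask,
    scale=1.0 / math.sqrt(head_dim),
)
print(ops.shape(attention_output))  # (2, 16, 8, 32)

Whether this lowers to a flash-attention kernel depends on the backend and on checks like the fused_attention_op_available() change in the second file below.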
@@ -249,15 +238,15 @@ def get_config(self):
         config = super().get_config()
         config.update(
             {
-                "num_query_heads": self._num_query_heads,
-                "num_key_value_heads": self._num_key_value_heads,
-                "rope_max_wavelength": self._rope_max_wavelength,
-                "rope_scaling_factor": self._rope_scaling_factor,
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
                 "kernel_initializer": keras.initializers.serialize(
                     self._kernel_initializer
                 ),
-                "sliding_window": self._sliding_window,
-                "dropout": self._dropout,
+                "sliding_window": self.sliding_window,
+                "dropout": self.dropout,
             }
         )
         return config
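The get_config() change above works because each config key now mirrors a public attribute that stores the matching constructor argument. A hypothetical minimal layer (not from this PR) showing that round-trip pattern:

import keras

class TinyAttentionConfigDemo(keras.layers.Layer):
    """Hypothetical layer illustrating the config round-trip pattern."""

    def __init__(self, num_query_heads=8, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are stored under matching public names.
        self.num_query_heads = num_query_heads
        self.dropout = dropout

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "dropout": self.dropout,
            }
        )
        return config

layer = TinyAttentionConfigDemo(num_query_heads=4, dropout=0.1)
clone = TinyAttentionConfigDemo.from_config(layer.get_config())
print(clone.num_query_heads, clone.dropout)  # 4 0.1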
File: the utility module that defines fused_attention_op_available()
@@ -71,6 +71,23 @@ def fused_attention_op_available():
             )
             return False
         return True
+    elif (
+        hasattr(keras.config, "is_flash_attention_enabled")
+        and keras.config.backend() == "torch"
+    ):
+        try:
+            from torch.backends.cuda import SDPAParams as SDPAParams
+            from torch.backends.cuda import (
+                can_use_flash_attention as can_use_flash_attention,
+            )
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current PyTorch "
+                "version. Please update it by following the official guide: "
+                "https://pytorch.org/get-started/locally/"
+            )
+            return False
+        return True
     else:
         return False

Review thread on this hunk:
- "This looks good! Can you please enable this?"
- "@pass-lin the test has not been enabled on the PyTorch backend. Can you please refer to the above comment on enabling it?"
- "I don't know if you have tested it on an A100. At present, the Gemma and Gemma3 flash-attn tests fail. This is true for both JAX and Torch."
- "@pctablet505 have you tested this? Can you please take a look?"
- "I'm not sure about it, I'll have to look into it."
- "@pctablet505 @divyashreepathihalli"
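To make the new torch branch concrete, here is a standalone sketch (an assumption-laden simplification, not the helper shipped in this PR) that combines the same backend check and import probe to decide whether the torch flash-attention symbols are available:

import logging

import keras

def torch_flash_attention_available():
    # Same idea as the diff above: only the torch backend is considered,
    # and support is assumed only if the installed PyTorch exposes the
    # flash-attention SDPA symbols.
    if not (
        hasattr(keras.config, "is_flash_attention_enabled")
        and keras.config.backend() == "torch"
    ):
        return False
    try:
        from torch.backends.cuda import SDPAParams  # noqa: F401
        from torch.backends.cuda import can_use_flash_attention  # noqa: F401
    except ImportError:
        logging.warning(
            "Flash attention is not supported in your current PyTorch "
            "version. Please update it by following the official guide: "
            "https://pytorch.org/get-started/locally/"
        )
        return False
    return True

print(torch_flash_attention_available())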
Review thread on the attribute renaming:
- "What is the reason behind the renaming?"
- "I'm just synchronizing it with the current state of this file in the repository: https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/mixtral/mixtral_attention.py"
- "Oh okay. Can you rebase your branch with master so that these don't show up as new changes?"