support flash-attn at torch backend #2257

Open
wants to merge 12 commits into base: master

Changes from 11 commits
2 changes: 1 addition & 1 deletion keras_hub/src/models/gemma/gemma_attention.py
@@ -152,7 +152,7 @@ def _compute_attention(
attention_mask = ops.expand_dims(attention_mask, axis=1)
attention_mask = ops.cast(attention_mask, dtype="bool")
# Only pass soft cap if needed, as not all Keras versions support it.
if self.logit_soft_cap:
if self.logit_soft_cap is not None:
kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
else:
kwargs = {}
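Aside for readers skimming the hunk above: one reading of the switch from a truthiness test to `is not None` is that a soft cap explicitly configured as `0.0` is falsy and would otherwise be dropped. A minimal, self-contained sketch of that difference (the values here are hypothetical, not taken from any preset):

```python
def soft_cap_kwargs_truthy(logit_soft_cap):
    # Old check: a cap explicitly configured as 0.0 is falsy and gets dropped.
    return {"attn_logits_soft_cap": logit_soft_cap} if logit_soft_cap else {}


def soft_cap_kwargs_is_not_none(logit_soft_cap):
    # New check: only an unset (None) cap is omitted.
    if logit_soft_cap is not None:
        return {"attn_logits_soft_cap": logit_soft_cap}
    return {}


print(soft_cap_kwargs_truthy(0.0))        # {}
print(soft_cap_kwargs_is_not_none(0.0))   # {'attn_logits_soft_cap': 0.0}
print(soft_cap_kwargs_is_not_none(None))  # {}
```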
73 changes: 31 additions & 42 deletions keras_hub/src/models/mixtral/mixtral_attention.py
@@ -27,19 +27,19 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self._num_query_heads = num_query_heads
self._num_key_value_heads = num_key_value_heads
self._sliding_window = sliding_window
self._dropout = dropout
self.num_query_heads = num_query_heads
Collaborator:
What is the reason behind the renaming?

@pass-lin (Contributor Author), May 25, 2025:
https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/mixtral/mixtral_attention.py
I'm just synchronizing this with the current state of the repository.

Collaborator:
Oh okay, can you rebase your branch with master so that these don't show up as new changes?

self.num_key_value_heads = num_key_value_heads
self.sliding_window = sliding_window
self.dropout = dropout

self._num_key_value_groups = num_query_heads // num_key_value_heads
self._rope_max_wavelength = rope_max_wavelength
self.num_key_value_groups = num_query_heads // num_key_value_heads
self.rope_max_wavelength = rope_max_wavelength

self._kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)

self._rope_scaling_factor = rope_scaling_factor
self.rope_scaling_factor = rope_scaling_factor

def build(self, inputs_shape):
# Einsum variables:
@@ -51,12 +51,12 @@ def build(self, inputs_shape):
# v = num key/value heads
# h = head dim
self._hidden_dim = inputs_shape[-1]
self._head_dim = self._hidden_dim // self._num_query_heads
self._head_dim = self._hidden_dim // self.num_query_heads
self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)

self.query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self._num_query_heads, self._head_dim),
output_shape=(None, self.num_query_heads, self._head_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="query",
@@ -67,7 +67,7 @@
equation="bkm,mvh->bkvh",
output_shape=(
None,
self._num_key_value_heads,
self.num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
@@ -80,7 +80,7 @@
equation="bkm,mvh->bkvh",
output_shape=(
None,
self._num_key_value_heads,
self.num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
@@ -89,31 +89,31 @@
)
self.value_dense.build(inputs_shape)

self._softmax = keras.layers.Softmax(
self.softmax = keras.layers.Softmax(
axis=-1,
dtype="float32",
name="attention_softmax",
)

self._dropout_layer = keras.layers.Dropout(
rate=self._dropout,
self.dropout_layer = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
)

self._output_dense = keras.layers.EinsumDense(
self.output_dense = keras.layers.EinsumDense(
equation="bquh,uhm->bqm",
output_shape=(None, self._hidden_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="attention_output",
)
self._output_dense.build(
(None, None, self._num_query_heads, self._head_dim)
self.output_dense.build(
(None, None, self.num_query_heads, self._head_dim)
)

self.rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self._rope_max_wavelength,
scaling_factor=self._rope_scaling_factor,
max_wavelength=self.rope_max_wavelength,
scaling_factor=self.rope_scaling_factor,
dtype=self.dtype_policy,
)

@@ -168,39 +168,34 @@ def _compute_key_value(x):

# [batch_shape, seq_len, num_key_value_heads, head_dim]
# -> [batch_shape, seq_len, num_heads, head_dim]
key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2)
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

attention_output = self._compute_attention(
query, key, value, attention_mask
)

attention_output = self._dropout_layer(
attention_output = self.dropout_layer(
attention_output, training=training
)

attention_output = self._output_dense(attention_output)
attention_output = self.output_dense(attention_output)

if cache is not None:
return attention_output, cache
return attention_output

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
return self._softmax(
attention_scores, attention_mask[:, None, :, :]
)
return self._softmax(attention_scores)
return self.softmax(attention_scores, attention_mask[:, None, :, :])
return self.softmax(attention_scores)

def _use_fused_attention_op(self):
if not fused_attention_op_available():
return False
if self.dropout > 0.0:
return False
if running_on_gpu():
# GPU never supports softcap in the fused op.
if self.logit_soft_cap is not None:
Collaborator:
This needs to return False in the JAX backend.

@pass-lin (Contributor Author):
Mixtral never uses self.logit_soft_cap, so I don't follow what you mean.

Collaborator:
I see! Okay.

return False
return gpu_supports_fused_attention_op()
elif running_on_tpu():
# TPU supports softcap on keras >= 3.10.
@@ -215,18 +210,12 @@ def _compute_attention(self, query, key, value, attention_mask=None):
attention_mask = ops.expand_dims(attention_mask, axis=1)
attention_mask = ops.cast(attention_mask, dtype="bool")

if self.logit_soft_cap:
kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
else:
kwargs = {}

attention_output = ops.dot_product_attention(
query,
key,
value,
mask=attention_mask,
scale=self._inv_norm_factor,
**kwargs,
)
return attention_output

@@ -249,15 +238,15 @@ def get_config(self):
config = super().get_config()
config.update(
{
"num_query_heads": self._num_query_heads,
"num_key_value_heads": self._num_key_value_heads,
"rope_max_wavelength": self._rope_max_wavelength,
"rope_scaling_factor": self._rope_scaling_factor,
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"kernel_initializer": keras.initializers.serialize(
self._kernel_initializer
),
"sliding_window": self._sliding_window,
"dropout": self._dropout,
"sliding_window": self.sliding_window,
"dropout": self.dropout,
}
)
return config
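As a side note on the grouped-query attention bookkeeping in the `_compute_key_value` hunk above: a standalone sketch, using standard `keras.ops` calls and made-up shapes, of the `ops.repeat` step that broadcasts the key/value heads up to the number of query heads.

```python
from keras import ops

batch, seq_len, head_dim = 2, 6, 4
num_query_heads, num_key_value_heads = 8, 2
num_key_value_groups = num_query_heads // num_key_value_heads  # 4

# [batch, seq_len, num_key_value_heads, head_dim]
key = ops.zeros((batch, seq_len, num_key_value_heads, head_dim))

# -> [batch, seq_len, num_query_heads, head_dim]: each query head now sees
# its group's shared key projection, as in the hunk above.
key = ops.repeat(key, repeats=num_key_value_groups, axis=2)
print(ops.shape(key))  # (2, 6, 8, 4)
```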
1 change: 1 addition & 0 deletions keras_hub/src/models/qwen_moe/qwen_moe_attention.py
@@ -67,6 +67,7 @@ def __init__(
self.rope_scaling_factor = rope_scaling_factor
self.use_sliding_window_attention = use_sliding_window_attention
self.sliding_window_size = sliding_window_size
self.logit_soft_cap = None

def build(self, inputs_shape):
# Einsum variables:
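The one-line qwen_moe change above gives the layer a `logit_soft_cap` attribute defaulting to `None`, presumably so the shared soft-cap check (as in the Gemma hunk earlier) takes the no-cap branch instead of raising `AttributeError`. A hedged sketch of that pattern with stand-in classes rather than the real layers:

```python
class WithDefault:
    def __init__(self):
        self.logit_soft_cap = None  # mirrors the qwen_moe __init__ change


class WithoutDefault:
    """Stand-in for an attention layer that never sets the attribute."""


def soft_cap_kwargs(layer):
    # Same check as in the Gemma hunk: pass the kwarg only when a cap is set.
    if layer.logit_soft_cap is not None:
        return {"attn_logits_soft_cap": layer.logit_soft_cap}
    return {}


print(soft_cap_kwargs(WithDefault()))  # {}
# soft_cap_kwargs(WithoutDefault())    # would raise AttributeError
```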
17 changes: 17 additions & 0 deletions keras_hub/src/utils/keras_utils.py
@@ -71,6 +71,23 @@ def fused_attention_op_available():
)
return False
return True
elif (
Collaborator:
This looks good! Can you please enable this test
https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/gemma/gemma_causal_lm_test.py#L101
on the PyTorch backend and make sure it passes on a supported GPU? (This may not be supported on the T4 our CI uses, so a demo Colab showing the tests passing on a supported GPU would be great.)

@pass-lin (Contributor Author), May 25, 2025:
[screenshot: models that reference the fused_attention_op_available() function]
These are the models that reference the fused_attention_op_available() function.
Here are the test results on an A100:
[screenshot: A100 test results]

Collaborator:
@pass-lin the test has not been enabled on the PyTorch backend. Can you please refer to the comment above on enabling it?

@pass-lin (Contributor Author):
I don't know whether you have tested it on an A100. At present, the Gemma and Gemma3 flash-attn test code fails, on both the JAX and Torch backends.
Could you instead design tests on models like Qwen and Llama, which are better suited to flash-attn?

Collaborator:
@pctablet505 - have you tested this? Can you please take a look?

Collaborator:
I'm not sure about it, I'll have to look into it.

@pass-lin (Contributor Author), May 28, 2025:
@pctablet505 @divyashreepathihalli
I'm confident this test is wrong, because it tests Gemma2, and Gemma2 does not support flash-attn.

@pctablet505 (Collaborator), May 30, 2025:
@pass-lin
I just verified that Gemma2 and Gemma3 can't use flash attention on an A100 GPU.
Gemma3 can use flash attention on TPU, or on GPUs with CUDA compute capability >= 9.0, i.e. the H series or later (for example, H100).

#21333

hasattr(keras.config, "is_flash_attention_enabled")
and keras.config.backend() == "torch"
):
try:
from torch.backends.cuda import SDPAParams as SDPAParams
from torch.backends.cuda import (
can_use_flash_attention as can_use_flash_attention,
)
except ImportError:
logging.warning(
"Flash attention is not supported in your current PyTorch "
"version. Please update it by following the official guide: "
"https://pytorch.org/get-started/locally/"
)
return False
return True
else:
return False
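
To see where this new branch matters, here is a trimmed-down sketch of the gating the attention layers in this PR do before calling the fused op. It is a simplification of the `_use_fused_attention_op` shown in the Mixtral hunk above (the real layers also handle logit soft capping and Keras version checks), and it assumes the helper functions are importable from `keras_hub.src.utils.keras_utils`, as the model modules do.

```python
from keras_hub.src.utils.keras_utils import (
    fused_attention_op_available,
    gpu_supports_fused_attention_op,
    running_on_gpu,
    running_on_tpu,
)


def use_fused_attention_op(dropout):
    # Simplified gating: use fused attention only when the backend reports
    # support and attention dropout is disabled.
    if not fused_attention_op_available():
        return False
    if dropout > 0.0:
        return False
    if running_on_gpu():
        return gpu_supports_fused_attention_op()
    if running_on_tpu():
        return True
    return False
```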

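On the review thread's point that Gemma3's fused path needs a TPU or a GPU with CUDA compute capability >= 9.0: a hedged sketch of how a test could skip itself on older devices. The `torch.cuda` calls are standard PyTorch API; the threshold and the test name are assumptions drawn from the discussion above, not something this PR adds.

```python
import pytest
import torch


def has_flash_capable_gpu(min_capability=(9, 0)):
    # (9, 0) corresponds to the H series (e.g. H100), per the review comment.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= min_capability


@pytest.mark.skipif(
    not has_flash_capable_gpu(),
    reason="flash attention path needs a GPU with compute capability >= 9.0",
)
def test_flash_attention_forward_pass():
    # Placeholder: the real test would run a model forward pass with flash
    # attention enabled (cf. keras.config.is_flash_attention_enabled above).
    ...
```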