Re-apply Fixed issue with dot_product_attention when using TPU. #21254 after addressing cuDNN/FlashAttention API updates #21333


Merged
25 commits merged on Jun 10, 2025
Changes from 22 commits
Commits (25)
04cd682
Update nn.py
pctablet505 May 6, 2025
1a74465
Update nn.py
pctablet505 May 6, 2025
c11eb81
Update nn.py
pctablet505 May 6, 2025
c81e18c
Update nn.py
pctablet505 May 6, 2025
d938e20
Update nn.py
pctablet505 May 7, 2025
f60811e
Update nn.py
pctablet505 May 7, 2025
b3ae323
Merge branch 'master' of https://github.com/pctablet505/keras
pctablet505 May 12, 2025
28eeb24
Update random_grayscale.py
pctablet505 May 12, 2025
de81e5b
Update keras/src/layers/preprocessing/image_preprocessing/random_gray…
pctablet505 May 12, 2025
66661ac
Update random_grayscale_test.py
pctablet505 May 12, 2025
c37f2b5
code reformat
pctablet505 May 13, 2025
498dece
Update random_grayscale_test.py
pctablet505 May 13, 2025
b0b5f63
Merge branch 'master' of https://github.com/pctablet505/keras
pctablet505 May 21, 2025
653f5b1
changed compute_output_spec
pctablet505 May 21, 2025
e681e4c
Merge branch 'keras-team:master' into master
pctablet505 May 21, 2025
27ad80b
Update random_grayscale.py
pctablet505 May 26, 2025
50f6292
Merge branch 'master' of https://github.com/pctablet505/keras
pctablet505 May 29, 2025
579cc11
Reapply "Fixed issue with dot_product_attention when using TPU. (#21…
pctablet505 May 29, 2025
7a0c547
Improve error handling in _can_use_flash_attention for better debugging
pctablet505 May 29, 2025
f7a2290
Revert "Improve error handling in _can_use_flash_attention for better…
pctablet505 May 29, 2025
8bae892
Fix JAX API compatibility and improve error handling in `_can_use_fla…
pctablet505 May 29, 2025
ee196cd
Updated `dot_product_attention`
pctablet505 May 29, 2025
40583c8
Update nn.py
pctablet505 Jun 7, 2025
7c918ba
Update nn.py
pctablet505 Jun 7, 2025
a927e7e
Merge branch 'keras-team:master' into master
pctablet505 Jun 10, 2025
190 changes: 152 additions & 38 deletions keras/src/backend/jax/nn.py
@@ -1062,6 +1062,8 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
q_seqlen=None,
kv_seqlen=None,
layout=_normalize_layout("BTNH"),
q_offsets=None,
kv_offsets=None,
)
check_is_flash_attention(
query,
@@ -1072,9 +1074,9 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
is_training=False,
)
return True
except:
except Exception as e:
if raise_error:
raise
raise e
return False


@@ -1126,16 +1128,17 @@ def wrap_flash_attention(
decoder_segment_ids,
custom_mask=None,
attn_logits_soft_cap=None,
head_shards=1,
q_seq_shards=1,
):
if decoder_segment_ids is not None:
assert query.shape[2] == decoder_segment_ids.q.shape[1], (
"Sharding along sequence dimension not allowed in tpu kernel "
"attention"
"Sharding along sequence dimension not allowed"
" in TPU kernel attention"
)

if custom_mask is not None:
mask = splash_attention_mask.NumpyMask(array=custom_mask)

else:
mask = splash_attention_mask.CausalMask(
shape=(query.shape[2], query.shape[2])
@@ -1147,8 +1150,8 @@
)
splash_kernel = splash_attention_kernel.make_splash_mha(
mask=multi_head_mask,
head_shards=1,
q_seq_shards=1,
head_shards=head_shards,
q_seq_shards=q_seq_shards,
attn_logits_soft_cap=attn_logits_soft_cap,
)

@@ -1168,6 +1171,38 @@ def dot_product_attention(
flash_attention=None,
attn_logits_soft_cap=None,
):
"""Computes dot-product attention given query, key, and value.

This is the core computation of attention that is used in transformers.
For TPU platforms, flash attention optimizations are automatically applied
when possible, and sharding parameters are inferred from the layout map
in the current distribution context.

Args:
query: Queries with shape `[batch, time, heads,
depth_k]`.
key: Keys with shape `[batch, time, heads,
depth_k]`.
value: Values with shape `[batch, time, heads,
depth_v]`.
bias: Optional bias with shape broadcastable to
`[batch, heads, dest_time, source_time]`.
mask: Optional mask with shape broadcastable to
`[batch, heads, dest_time, source_time]`.
scale: Float. Optional scale that is applied to the attention
computation.
is_causal: Boolean. Whether causal masking is applied.
flash_attention: Boolean. Whether to use flash attention optimization
for increased performance. Defaults to None, which means it will
be auto-determined based on the platform, input shapes and
compatibility.
attn_logits_soft_cap: Float. Optional float to softly cap attention
logits, helping avoid numerical instability. Applied as:
`logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`.

Returns:
JAX Array of shape `[batch, time, heads, depth_v]`.
"""
query = convert_to_tensor(query)
key = convert_to_tensor(key)
value = convert_to_tensor(value)
@@ -1177,47 +1212,123 @@ def dot_product_attention(
f"Received: query.shape={query.shape}, key.shape={key.shape}, "
f"value.shape={value.shape}."
)

# Check platform
platform = jax.devices()[0].platform
is_tpu = platform == "tpu"

# Determine flash attention compatibility
if flash_attention is None:
flash_attention = _can_use_flash_attention(query, key, value, bias)
elif flash_attention is True:
# Use `raise_error=True` to provide more details if the inputs fail to
# use flash attention
_can_use_flash_attention(query, key, value, bias, raise_error=True)

if jax.devices()[0].platform == "tpu":
# Transpose to ('batch', 'heads', 'length', 'kv')
query = jnp.transpose(query, axes=(0, 2, 1, 3))
key = jnp.transpose(key, axes=(0, 2, 1, 3))
value = jnp.transpose(value, axes=(0, 2, 1, 3))
B, H, S, KV = query.shape

segment_ids = jnp.ones([B, S])
# {token_ids, padding_mask, segment_ids} enable packing
out = wrap_flash_attention(
query,
key,
value,
decoder_segment_ids=splash_attention_kernel.SegmentIds(
segment_ids, segment_ids
),
custom_mask=mask,
attn_logits_soft_cap=attn_logits_soft_cap,
# TPU-specific flash attention path
if is_tpu and flash_attention:
# Get sharding parameters from distribution context
head_shards = 1
q_seq_shards = 1
try:
from keras.src.distribution.distribution_lib import ModelParallel
from keras.src.distribution.distribution_lib import (
distribution as get_dist,
)

# Get current distribution if available
dist = get_dist()
if dist and isinstance(dist, ModelParallel):
mesh = dist.device_mesh
if "model" in mesh.axis_names:
model_dim_index = mesh.axis_names.index("model")
# Set head_shards based on the model dimension of the mesh
head_shards = mesh.shape[model_dim_index]
# Typically keep q_seq_shards=1 for best performance
q_seq_shards = 1
except (ImportError, ValueError, AttributeError):
# Use default values if detection fails
head_shards = 1
q_seq_shards = 1
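# For illustration, a distribution that this detection picks up could be
# configured as below (a sketch assuming 4 devices and the public
# keras.distribution API; the axis names are illustrative):
#
#     import keras
#     devices = keras.distribution.list_devices()
#     mesh = keras.distribution.DeviceMesh(
#         shape=(1, 4), axis_names=("batch", "model"), devices=devices
#     )
#     layout_map = keras.distribution.LayoutMap(mesh)
#     keras.distribution.set_distribution(
#         keras.distribution.ModelParallel(layout_map=layout_map)
#     )
#     # head_shards is then inferred as 4, the size of the "model" axis.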
# Transpose to ('batch', 'heads', 'length', 'head_dim')
query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3))

bs, num_heads, q_len, head_dim = query_tpu_layout.shape

# Apply scale to query if provided
if scale is not None:
# The TPU kernel applies 1/sqrt(head_dim) internally; to achieve an
# overall QK^T * scale, scale the query by (scale * sqrt(head_dim)).
query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
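# For example, with head_dim=64 and scale=0.125 (= 1/sqrt(64)), the factor
# above is 0.125 * 8 = 1.0, and the kernel's internal 1/sqrt(head_dim)
# recovers the intended overall scale of 0.125.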

# Create segment IDs for Splash Attention (for packing/batching)
segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32)
decoder_segment_ids = splash_attention_kernel.SegmentIds(
q=segment_ids, kv=segment_ids
)
out = jnp.transpose(out, axes=(0, 2, 1, 3))
return out

# `dot_product_attention` is only available in jax>=0.4.31
# Process mask for Splash Attention
custom_mask = None
if mask is not None:
mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask

if mask_bool.ndim == 3 and mask_bool.shape[0] == bs:
custom_mask = mask_bool[0]
elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs:
custom_mask = mask_bool[0, 0]
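# The splash kernel expects a single [q_len, kv_len] mask, so the mask is
# assumed to be shared across batch and heads; only the first slice is used.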

if is_causal and custom_mask is not None:
causal_mask = jnp.tril(
jnp.ones((q_len, q_len), dtype=jnp.bool_)
)
custom_mask = jnp.logical_and(custom_mask, causal_mask)

if custom_mask is None and is_causal:
custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))

try:
output = wrap_flash_attention(
query_tpu_layout,
key_tpu_layout,
value_tpu_layout,
decoder_segment_ids=decoder_segment_ids,
custom_mask=custom_mask,
attn_logits_soft_cap=attn_logits_soft_cap,
head_shards=head_shards,
q_seq_shards=q_seq_shards,
)
# Transpose output back to Keras layout
return jnp.transpose(output, axes=(0, 2, 1, 3))
except Exception:
flash_attention = False
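# Splash attention failed at trace/execution time; disable the flag so
# the standard JAX path below runs with the XLA implementation.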

# JAX native dot_product_attention for GPU or fallback for TPU
if hasattr(jax.nn, "dot_product_attention"):
return jax.nn.dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
scale=scale,
is_causal=is_causal,
implementation="cudnn" if flash_attention else "xla",
)
try:
return jax.nn.dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
scale=scale,
is_causal=is_causal,
implementation="cudnn" if flash_attention else "xla",
)
except Exception:
# If flash attention fails, fall back to XLA implementation
if flash_attention:
return jax.nn.dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
scale=scale,
is_causal=is_causal,
implementation="xla",
)
raise

if flash_attention:
raise RuntimeError(
@@ -1228,6 +1339,9 @@ def dot_product_attention(
# Ref: jax.nn.dot_product_attention
# https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
# Does not support the `query_seq_lengths` and `key_value_seq_lengths` args

# Fallback to custom XLA implementation
# This is the reference implementation from jax.nn.dot_product_attention
output_shape = query.shape
_, _, K, H = key.shape
scale = (1.0 / jnp.sqrt(H)) if scale is None else scale
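
For reference, the new control flow can be exercised through the public ops API roughly as follows. This is a sketch rather than code from the PR: it assumes the JAX backend is active and that keras.ops.dot_product_attention forwards the flash_attention flag to this backend function; shapes and sizes are illustrative.

import numpy as np
from keras import ops

# [batch, time, heads, head_dim]
query = np.random.normal(size=(2, 128, 8, 64)).astype("float32")
key = np.random.normal(size=(2, 128, 8, 64)).astype("float32")
value = np.random.normal(size=(2, 128, 8, 64)).astype("float32")

# flash_attention=None (default): auto-detected. On TPU this takes the splash
# attention path above; on GPU it requests the cuDNN implementation of
# jax.nn.dot_product_attention; otherwise it falls back to XLA.
out = ops.dot_product_attention(query, key, value, is_causal=True)

# flash_attention=True: incompatible inputs raise a detailed error via
# _can_use_flash_attention(..., raise_error=True) instead of silently
# falling back.
out = ops.dot_product_attention(
    query, key, value, is_causal=True, flash_attention=True
)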