
Re-apply "Fixed issue with dot_product_attention when using TPU" (#21254) after addressing cuDNN/FlashAttention API updates #21333

Merged: 25 commits merged into keras-team:master on Jun 10, 2025

Conversation

@pctablet505 (Collaborator) commented May 29, 2025

This PR reapplies the changes from #21254 (“Fixed issue with dot_product_attention when using TPU”), which was previously reverted in #21329 due to a test failure involving the gemma2 Flash Attention test on A100 GPUs.

Root Cause Analysis:

Conclusion:
The gemma2 and gemma3 models cannot use Flash Attention on GPU until JAX adds support for larger head dimensions. The original PR did not cause a regression; the limitation is due to upstream JAX/cuDNN constraints, not this code.

Changes in this PR:

  • Updates _can_use_flash_attention to use the correct signature for check_layout, matching the latest JAX API.
  • Adds error handling and clarifies diagnostic output for unsupported head dimension values.
  • Reapplies the original fix for dot_product_attention on TPU.

Note:
This PR does not re-enable Flash Attention for gemma2/gemma3 on A100 GPUs—they remain unsupported by JAX at this time. See the upstream JAX issue for future support.
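
A minimal sketch of what this looks like from the user side (not from the PR itself; shapes are illustrative and it assumes a recent Keras 3 on the JAX backend, where `keras.ops.dot_product_attention` exposes the `flash_attention` argument used by this backend function):

```python
# Hedged sketch: explicitly requesting flash attention on an A100 now
# surfaces JAX's own error instead of a generic failure. head_dim=256 is the
# gemma-like case that remains unsupported on compute capability 8.0.
import numpy as np
from keras import ops

q = np.zeros((1, 16, 8, 256), dtype="float32")
k = np.zeros((1, 16, 8, 256), dtype="float32")
v = np.zeros((1, 16, 8, 256), dtype="float32")

# Raises NotImplementedError ("The head dim must be <= 128 ...") on an A100;
# leaving flash_attention=None instead auto-detects and falls back to XLA.
out = ops.dot_product_attention(q, k, v, is_causal=True, flash_attention=True)
```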

pctablet505 and others added 22 commits May 6, 2025 10:25
Corrected indentation in doc string
Fixed issue with passing a single image without batch dimension.
…scale.py

Co-authored-by: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Test case for unbatched inputs
Testcase for checking both unbatched and batched single image inputs.
There was a bug that was causing a cycle in the graph.
removed the use of tree.map_structure
Enhanced the _can_use_flash_attention function to provide more detailed
error messages when flash attention compatibility checks fail.

Changes:
- Replace generic exception catching with specific error propagation
- When raise_error=True, directly re-raise original exceptions from
  check_layout() and check_is_flash_attention() functions
- Preserve detailed error context from JAX internal validation functions
- Maintain existing behavior when raise_error=False (returns False)

This improves debugging experience by surfacing specific technical details
about tensor layout incompatibilities, cuDNN version requirements, and
other flash attention compatibility issues.

Relates to keras-hub PR keras-team#2257 and addresses flash attention debugging needs.
…sh_attention`

Changes:
- Add missing q_offsets=None and kv_offsets=None parameters to check_layout()
  call to match updated JAX function signature
- Replace bare `except:` with `except Exception as e:` and `raise e` to
  preserve detailed error messages from JAX validation functions
- Maintain existing fallback behavior when raise_error=False

This resolves compatibility issues with newer JAX versions and improves
debugging experience by surfacing specific technical details about
flash attention compatibility failures.
Simplified the check for `flash_attention` by removing redundant checks that are already done in `_can_use_flash_attention`.
@github-actions bot added the Gemma (Gemma model specific issues) label May 29, 2025
@codecov-commenter commented May 29, 2025

Codecov Report

Attention: Patch coverage is 8.51064% with 43 lines in your changes missing coverage. Please review.

Project coverage is 82.67%. Comparing base (24f104e) to head (a927e7e).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| keras/src/backend/jax/nn.py | 8.51% | 42 Missing and 1 partial ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21333      +/-   ##
==========================================
- Coverage   82.72%   82.67%   -0.06%     
==========================================
  Files         565      565              
  Lines       54904    54941      +37     
  Branches     8520     8529       +9     
==========================================
+ Hits        45418    45421       +3     
- Misses       7399     7433      +34     
  Partials     2087     2087              
| Flag | Coverage Δ |
| --- | --- |
| keras | 82.48% <8.51%> (-0.06%) ⬇️ |
| keras-jax | 63.51% <8.51%> (-0.04%) ⬇️ |
| keras-numpy | 58.66% <0.00%> (-0.04%) ⬇️ |
| keras-openvino | 33.55% <0.00%> (-0.03%) ⬇️ |
| keras-tensorflow | 63.89% <0.00%> (-0.05%) ⬇️ |
| keras-torch | 63.53% <0.00%> (-0.05%) ⬇️ |

Flags with carried forward coverage won't be shown.

@fchollet (Collaborator) left a comment

LGTM, thank you for the fix.

@pctablet505 (Collaborator, Author) commented

After further debugging, I can verify that this is not a bug in JAX either. The A100 has CUDA compute capability 8.0, while the JAX code requires a minimum compute capability of 9.0 to support head dimensions up to 256.

https://developer.nvidia.com/cuda-gpus#:~:text=8.0,A100%0ANVIDIA%20A30

The A100 is an Ampere-series GPU with compute capability 8.0; the larger head dimension requires at least an H-series (Hopper) GPU.

https://github.com/jax-ml/jax/blob/jax-v0.6.1/jax/_src/cudnn/fused_attention_stablehlo.py#L380
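
For illustration, a minimal sketch (assuming a CUDA-enabled JAX install) of the gate referenced above; the compute capability queried here is exactly what `check_compute_capability` inspects:

```python
# On an A100 this reports (8, 0), below the (9, 0) Hopper threshold, so JAX
# keeps the flash-attention head-dim limit at 128 rather than 256.
import jax

device = jax.local_devices(backend="gpu")[0]
current = tuple(int(x) for x in device.compute_capability.split("."))
print(current)            # (8, 0) on an A100
print(current >= (9, 0))  # False -> H_max stays at 128
```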

@pctablet505 (Collaborator, Author) commented May 30, 2025

> Can you please test this on a KerasHub model on GPU and check if flash attention is being called? Please link the test colab.

I've tested it on the llama3.2_instruct_1b model and flash attention is working fine for it.

Below is my experimentation setup.

Adding a print statement in dot_product_attention to check whether flash_attention is True or False:

def dot_product_attention(
    query,
    key,
    value,
    bias=None,
    mask=None,
    scale=None,
    is_causal=False,
    flash_attention=None,
    attn_logits_soft_cap=None,
):
    """Computes dot-product attention given query, key, and value.

    This is the core computation of attention that is used in transformers.
    For TPU platforms, flash attention optimizations are automatically applied
    when possible, and sharding parameters are inferred from the layout map
    in the current distribution context.

    Args:
        query: Queries with shape `[batch, time, heads,
            depth_k]`.
        key: Keys with shape `[batch, time, heads,
            depth_k]`.
        value: Values with shape `[batch, time, heads,
            depth_v]`.
        bias: Optional bias with shape broadcastable to
            `[batch, heads, dest_time, source_time]`.
        mask: Optional mask with shape broadcastable to
            `[batch, heads, dest_time, source_time]`.
        scale: Float. Optional scale that is applied to the attention
            computation.
        is_causal: Boolean. Specifying whether causal masking is applied.
        flash_attention: Boolean. Whether to use flash attention optimization
            for increased performance. Default to None, which means it will
            be auto-determined based on the platform, input shapes and
            compatibility.
        attn_logits_soft_cap: Float. Optional float to softly cap attention
            logits to avoid numerical stability issues. Applied as:
            `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`.

    Returns:
        JAX Array of shape `[batch, time, heads, depth_v]`.
    """
    query = convert_to_tensor(query)
    key = convert_to_tensor(key)
    value = convert_to_tensor(value)
    if len(query.shape) != 4 or len(key.shape) != 4 or len(value.shape) != 4:
        raise ValueError(
            "`dot_product_attention` only supports 4D inputs. "
            f"Received: query.shape={query.shape}, key.shape={key.shape}, "
            f"value.shape={value.shape}."
        )

    # Check platform
    platform = jax.devices()[0].platform
    is_tpu = platform == "tpu"

    # Determine flash attention compatibility
    if flash_attention is None:
        flash_attention = _can_use_flash_attention(query, key, value, bias)
    elif flash_attention is True:
        # Use `raise_error=True` to provide more details if the inputs failed to
        # use flash attention
        _can_use_flash_attention(query, key, value, bias, raise_error=True)
    print('flash_attention: ',flash_attention)
    # TPU-specific flash attention path
    if is_tpu and flash_attention:
        # Default sharding parameters; overridden below when a ModelParallel
        # distribution with a "model" mesh axis is active
        head_shards = 1
        q_seq_shards = 1
        # Get sharding parameters from distribution context
        try:
            from keras.src.distribution.distribution_lib import ModelParallel
            from keras.src.distribution.distribution_lib import (
                distribution as get_dist,
            )

            # Get current distribution if available
            dist = get_dist()
            if dist and isinstance(dist, ModelParallel):
                mesh = dist.device_mesh
                if "model" in mesh.axis_names:
                    model_dim_index = mesh.axis_names.index("model")
                    # Set head_shards based on the model dimension of the mesh
                    head_shards = mesh.shape[model_dim_index]
                    # Typically keep q_seq_shards=1 for best performance
                    q_seq_shards = 1
        except (ImportError, ValueError, AttributeError):
            # Use default values if detection fails
            head_shards = 1
            q_seq_shards = 1
        # Transpose to ('batch', 'heads', 'length', 'head_dim')
        query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
        key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
        value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3))

        bs, num_heads, q_len, head_dim = query_tpu_layout.shape

        # Apply scale to query if provided
        if scale is not None:
            # TPU kernel applies 1/sqrt(head_dim) internally, to achieve
            # overall QK^T * scale, scale query by (scale * sqrt(head_dim))
            query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))

        # Create segment IDs for Splash Attention (for packing/batching)
        segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32)
        decoder_segment_ids = splash_attention_kernel.SegmentIds(
            q=segment_ids, kv=segment_ids
        )

        # Process mask for Splash Attention
        custom_mask = None
        if mask is not None:
            mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask

            if mask_bool.ndim == 3 and mask_bool.shape[0] == bs:
                custom_mask = mask_bool[0]
            elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs:
                custom_mask = mask_bool[0, 0]

            if is_causal and custom_mask is not None:
                causal_mask = jnp.tril(
                    jnp.ones((q_len, q_len), dtype=jnp.bool_)
                )
                custom_mask = jnp.logical_and(custom_mask, causal_mask)

        if custom_mask is None and is_causal:
            custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))

        try:
            output = wrap_flash_attention(
                query_tpu_layout,
                key_tpu_layout,
                value_tpu_layout,
                decoder_segment_ids=decoder_segment_ids,
                custom_mask=custom_mask,
                attn_logits_soft_cap=attn_logits_soft_cap,
                head_shards=head_shards,
                q_seq_shards=q_seq_shards,
            )
            # Transpose output back to Keras layout
            return jnp.transpose(output, axes=(0, 2, 1, 3))
        except Exception:
            flash_attention = False

    # JAX native dot_product_attention for GPU or fallback for TPU
    if hasattr(jax.nn, "dot_product_attention"):
        try:
            return jax.nn.dot_product_attention(
                query,
                key,
                value,
                bias=bias,
                mask=mask,
                scale=scale,
                is_causal=is_causal,
                implementation="cudnn" if flash_attention else "xla",
            )
        except Exception:
            # If flash attention fails, fall back to XLA implementation
            if flash_attention:
                return jax.nn.dot_product_attention(
                    query,
                    key,
                    value,
                    bias=bias,
                    mask=mask,
                    scale=scale,
                    is_causal=is_causal,
                    implementation="xla",
                )
            raise

    if flash_attention:
        raise RuntimeError(
            "Flash attention is not supported in your current JAX version. "
            "Please update it by following the official guide: "
            "https://jax.readthedocs.io/en/latest/installation.html"
        )
    # Ref: jax.nn.dot_product_attention
    # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
    # Does not support the `query_seq_lengths` and `key_value_seq_lengths` args

    # Fallback to custom XLA implementation
    # This is the reference implementation from jax.nn.dot_product_attention
    output_shape = query.shape
    _, _, K, H = key.shape
    scale = (1.0 / jnp.sqrt(H)) if scale is None else scale

    # _dot_product_attention_xla
    B, T, N, H = query.shape
    G = N // K
    query = jnp.reshape(query, (B, T, K, G, H))

    def _reshape_to_grouped(t):
        if t is not None:
            tB, tN, tT, tS = t.shape
            if tN == 1:
                t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
            else:
                assert tN == N
                t = jnp.reshape(t, (tB, K, G, tT, tS))
        return t

    bias = _reshape_to_grouped(bias)
    mask = _reshape_to_grouped(mask)
    vmapped_fn = jax.vmap(
        _dot_product_attention_core,
        in_axes=(3, None, None, 2, 2, None, None),
        out_axes=3,
    )
    encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
    return jnp.reshape(encoded, output_shape)
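
A hedged usage sketch for this experiment (shapes are illustrative, JAX backend assumed): calling the public op routes through the backend function above, so the added `print('flash_attention: ', ...)` line shows which path was selected.

```python
import numpy as np
from keras import ops

batch, seq_len, num_heads, head_dim = 1, 128, 8, 64
q = np.random.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")
k = np.random.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")
v = np.random.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")

# Triggers the debug print in the patched dot_product_attention above.
out = ops.dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # (1, 128, 8, 64)
```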

Adding debug statements in _can_use_flash_attention:

def _can_use_flash_attention(query, key, value, bias, raise_error=False):
    """Verify the availability of flash attention."""
    try:
        from jax._src.cudnn.fused_attention_stablehlo import _normalize_layout
        from jax._src.cudnn.fused_attention_stablehlo import (
            check_compute_capability,
        )
        from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
        from jax._src.cudnn.fused_attention_stablehlo import (
            check_is_flash_attention,
        )
        from jax._src.cudnn.fused_attention_stablehlo import check_layout
        from jax.nn import dot_product_attention as dot_product_attention
    except ImportError:
        if raise_error:
            raise ImportError(
                "Flash attention is not supported in your current JAX version. "
                "Please update it by following the official guide: "
                "https://jax.readthedocs.io/en/latest/installation.html"
            )
        return False

    if jax.devices()[0].platform == "tpu":
        return True
    try:
        # Check if cuDNN is installed and raise RuntimeError if cuDNN is not
        # detected
        cudnn_version = check_cudnn_version()
        # Only support at least Ampere
        if not check_compute_capability("8.0"):
            raise RuntimeError("Require at least Ampere arch to run")
        # Check inputs layout
        check_layout(
            query,
            key,
            value,
            bias,
            q_seqlen=None,
            kv_seqlen=None,
            layout=_normalize_layout("BTNH"),
            q_offsets=None,
            kv_offsets=None,
        )
        print('check_layout: passed')
        print('cudnn_version: ',cudnn_version)
        check_is_flash_attention(
            query,
            key,
            _normalize_layout("BTNH"),
            cudnn_version,
            bias is not None,
            is_training=False,
        )
        print('check_is_flash_attention: passed')
        return True
    except Exception as e:
        if raise_error:
            raise e
        return False

And adding debug statements in the JAX code:

def check_is_flash_attention(
    query, key, layout: int, cudnn_version, has_bias, is_training, is_packed=False,
    is_fp8=False):
    # Extract sequence length (T) and head dim (H) based on layout
    if layout == AttentionLayout.BNTH.value:
        _, _, T, H = query.shape
        _, _, S, _ = key.shape
    else:
        _, T, _, H = query.shape
        _, S, _, _ = key.shape

    # Flash attention conditions
    if is_fp8:
        # FP8 specific conditions
        if not ((is_training and H == 128 and T % 128 == 0 and S % 128 == 0) or
                (not is_training and H <= 256 and H % 16 == 0)):
            raise NotImplementedError(
                f"Unsupported sequence length Q {T}, KV {S} and head dim {H} for FP8."
            )
    else:
        # bf16/fp16 attention conditions
        # Check the head dim.
        print('is_cuda_compute_capabality: ',is_cuda_compute_capability_equal("9.0"))
        is_on_hopper = is_cuda_compute_capability_equal("9.0")
        print('cudnn_version: ',cudnn_version)
        H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128
        if not (H <= H_max and H % 8 == 0):
          print('if block 1')
          raise NotImplementedError(
              f"The head dim must be <= {H_max} and a mutiple of 8, "
              f"but got {H}."
          )

        # Check patterns with bias, seqlen should be divisible by 2
        if (is_training and has_bias and (T % 2 != 0 or S % 2 != 0)):
          print('if block 2')
          raise NotImplementedError(
              f"Unsupported sequence length Q {T}, KV {S}."
          )

        if is_packed and (cudnn_version < 90600 or not check_compute_capability("9.0")):
          print('if block 3')
          raise NotImplementedError(
            "Packed layout requires cudnn version >= 9.6 and at least hopper arch.")

def check_cudnn_version():
  # check if cuDNN is installed
  if cuda_versions is None:
    raise RuntimeError("cuDNN is not detected.")
  return cuda_versions.cudnn_get_version()

def check_compute_capability(capability):
  print('check_compute_capability xla backend version', xla_bridge.get_backend().platform_version)
  if not 'cuda' in xla_bridge.get_backend().platform_version:
    return False
  d, *_ = jax.local_devices(backend="gpu")
  target = tuple(int(x) for x in capability.split("."))
  current = tuple(int(x) for x in d.compute_capability.split("."))
  print('current, target, current>target',current, target,current>target)

  return current >= target

@pctablet505 (Collaborator, Author) commented

For Llama3.2 I'm getting the below output:

check_compute_capability xla backend version PJRT C API
cuda 12080
current, target, current>target (8, 0) (8, 0) False
check_layout: passed
cudnn_version:  91001
is_cuda_compute_capabality:  False
cudnn_version:  91001
check_is_flash_attention: passed
flash_attention:  True
is_cuda_compute_capabality:  False
cudnn_version:  91001
check_compute_capability xla backend version PJRT C API
cuda 12080
current, target, current>target (8, 0) (8, 0) False
check_layout: passed
cudnn_version:  91001
is_cuda_compute_capabality:  False
cudnn_version:  91001
check_is_flash_attention: passed
flash_attention:  True
is_cuda_compute_capabality:  False
cudnn_version:  91001
check_compute_capability xla backend version PJRT C API

This verifies that flash_attention is working.

@pctablet505 (Collaborator, Author) commented May 30, 2025

Whereas for Gemma3 I get the below output:

check_compute_capability xla backend version PJRT C API
cuda 12080
current, target, current>target (8, 0) (8, 0) False
check_layout: passed
cudnn_version:  91001
is_cuda_compute_capabality:  False
cudnn_version:  91001
if block 1
flash_attention:  False
check_compute_capability xla backend version PJRT C API
cuda 12080
current, target, current>target (8, 0) (8, 0) False
check_layout: passed
cudnn_version:  91001
is_cuda_compute_capabality:  False
cudnn_version:  91001
if block 1
flash_attention:  False
check_compute_capability xla backend version PJRT C API
cuda 12080
current, target, current>target (8, 0) (8, 0) False
check_layout: passed
cudnn_version:  91001
is_cuda_compute_capabality:  False
cudnn_version:  91001
if block 1
flash_attention:  False
check_compute_capability xla backend version PJRT C API
cuda 12080
current, target, current>target (8, 0) (8, 0) False
check_layout: passed
cudnn_version:  91001
is_cuda_compute_capabality:  False
cudnn_version:  91001
if block 1
flash_attention:  False

@pctablet505 (Collaborator, Author) commented

@divyashreepathihalli
I've added the justification for why gemma3 doesn't support flash_attention on an A100 GPU. It is not a bug in either Keras or JAX.

Please go through the above comments.

@Copilot (Copilot AI) left a comment

Pull Request Overview

Reapplies the original TPU fix for dot_product_attention, updates to match the latest JAX API signature, and improves flash-attention error handling and documentation.

  • Updates _can_use_flash_attention to pass the new q_offsets/kv_offsets to check_layout.
  • Adds named exception catches to preserve error context and improves diagnostic messages.
  • Reintroduces and extends the TPU-optimized flash-attention path with sharding support and explanatory docstrings.
Comments suppressed due to low confidence (1)

keras/src/backend/jax/nn.py:1262

  • The math module isn’t imported in this file, so math.sqrt will raise a NameError. Add import math at the top of the module.
query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
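
For reference, the fix the comment asks for is a one-line import; a sketch of the top of `keras/src/backend/jax/nn.py` after the change:

```python
import math  # needed for math.sqrt(head_dim) in the TPU flash-attention path

# ...later, inside dot_product_attention's TPU branch:
# query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
```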

@divyashreepathihalli (Collaborator) commented

Hi Rahul!
Can you please confirm FA is working for Gemma 1 and 2 and Stable diffusion on GPU and TPU like it was before?

@pctablet505 (Collaborator, Author) commented Jun 5, 2025

> Hi Rahul! Can you please confirm FA is working for Gemma 1 and 2 and Stable diffusion on GPU and TPU like it was before?

Hi @divyashreepathihalli,
Below is my test result after verifying it.

Test Results

| Model | GPU (A100) | TPU (v2-8) |
| --- | --- | --- |
| Gemma 2B | False | True |
| Gemma2 2B | NA (no print statements working) | True |
| Gemma3 1B | False | True |
| Llama3.2 1B | True | True |
| Stable Diffusion 3 Medium | True | True |
| Stable Diffusion 3.5 Large | True | NA (can't test due to lack of memory) |
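
A hedged sketch of how each row was exercised (the preset string is the one reported above; the exact KerasHub preset name may differ). With the debug print added to `dot_product_attention`, a short generate() call reveals whether the flash path is selected on the current device:

```python
import keras_hub

# Preset name taken from the comment above; adjust if it differs in KerasHub.
model = keras_hub.models.CausalLM.from_preset("llama3.2_instruct_1b")
model.generate("The quick brown fox", max_length=32)
# The console then shows e.g. "flash_attention:  True" on GPU for Llama3.2 1B.
```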

@google-ml-butler bot removed the "ready to pull" (Ready to be merged into the codebase) label Jun 7, 2025
@google-ml-butler bot added the kokoro:force-run and "ready to pull" (Ready to be merged into the codebase) labels Jun 7, 2025
@google-ml-butler bot removed the "ready to pull" (Ready to be merged into the codebase) label Jun 10, 2025
@fchollet merged commit 0c0ec1a into keras-team:master on Jun 10, 2025
9 of 10 checks passed
Labels: Gemma (Gemma model specific issues), size:M

6 participants